You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

load_source.py 1.9KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import torch
  2. import os
  3. from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets
  4. from pytorch_adapt.containers import Models, Optimizers, LRSchedulers
  5. from pytorch_adapt.models import Discriminator, office31C, office31G
  6. from pytorch_adapt.containers import Misc
  7. from pytorch_adapt.containers import LRSchedulers
  8. from classifier_adapter import ClassifierAdapter
  9. from utils import HP, DAModels
  10. import copy
  11. import matplotlib.pyplot as plt
  12. import torch
  13. import os
  14. import gc
  15. from datetime import datetime
  16. from pytorch_adapt.datasets import DataloaderCreator, get_office31
  17. from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite
  18. from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory
  19. from pytorch_adapt.frameworks.ignite import (
  20. CheckpointFnCreator,
  21. IgniteValHookWrapper,
  22. checkpoint_utils,
  23. )
  24. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  25. def get_source_trainer(checkpoint_dir):
  26. G = office31G(pretrained=False).to(device)
  27. C = office31C(pretrained=False).to(device)
  28. optimizers = Optimizers((torch.optim.Adam, {"lr": 1e-4}))
  29. lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99}))
  30. models = Models({"G": G, "C": C})
  31. adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)
  32. checkpoint_fn = CheckpointFnCreator(dirname=checkpoint_dir, require_empty=False)
  33. sourceAccuracyValidator = AccuracyValidator()
  34. val_hooks = [ScoreHistory(sourceAccuracyValidator)]
  35. new_trainer = Ignite(
  36. adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn, device=device
  37. )
  38. objs = [
  39. {
  40. "engine": new_trainer.trainer,
  41. **checkpoint_utils.adapter_to_dict(new_trainer.adapter),
  42. }
  43. ]
  44. for to_load in objs:
  45. checkpoint_fn.load_best_checkpoint(to_load)
  46. return new_trainer