You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

models.py 3.2KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import torch
  2. import os
  3. from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets
  4. from pytorch_adapt.containers import Models, Optimizers, LRSchedulers
  5. from pytorch_adapt.models import Discriminator, office31C, office31G
  6. from pytorch_adapt.containers import Misc
  7. from pytorch_adapt.layers import RandomizedDotProduct
  8. from pytorch_adapt.layers import MultipleModels, CORALLoss, MMDLoss
  9. from pytorch_adapt.utils import common_functions
  10. from pytorch_adapt.containers import LRSchedulers
  11. from classifier_adapter import ClassifierAdapter
  12. from load_source import get_source_trainer
  13. from utils import HP, DAModels
  14. import copy
  15. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  16. def get_model(model_name, hp: HP, source_checkpoint_dir, data_root, source_domain):
  17. if source_checkpoint_dir:
  18. source_trainer = get_source_trainer(source_checkpoint_dir)
  19. G = copy.deepcopy(source_trainer.adapter.models["G"])
  20. C = copy.deepcopy(source_trainer.adapter.models["C"])
  21. D = Discriminator(in_size=2048, h=1024).to(device)
  22. else:
  23. weights_root = os.path.join(data_root, "weights")
  24. G = office31G(pretrained=True, model_dir=weights_root).to(device)
  25. C = office31C(domain=source_domain, pretrained=True,
  26. model_dir=weights_root).to(device)
  27. D = Discriminator(in_size=2048, h=1024).to(device)
  28. optimizers = Optimizers((torch.optim.Adam, {"lr": hp.lr}))
  29. lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": hp.gamma}))
  30. if model_name == DAModels.DANN:
  31. models = Models({"G": G, "C": C, "D": D})
  32. adapter = DANN(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)
  33. elif model_name == DAModels.CDAN:
  34. models = Models({"G": G, "C": C, "D": D})
  35. misc = Misc({"feature_combiner": RandomizedDotProduct([2048, 31], 2048)})
  36. adapter = CDAN(models=models, misc=misc, optimizers=optimizers, lr_schedulers=lr_schedulers)
  37. elif model_name == DAModels.MCD:
  38. C1 = common_functions.reinit(copy.deepcopy(C))
  39. C_combined = MultipleModels(C, C1)
  40. models = Models({"G": G, "C": C_combined})
  41. adapter= MCD(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)
  42. elif model_name == DAModels.SYMNET:
  43. C1 = common_functions.reinit(copy.deepcopy(C))
  44. C_combined = MultipleModels(C, C1)
  45. models = Models({"G": G, "C": C_combined})
  46. adapter= SymNets(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)
  47. elif model_name == DAModels.MMD:
  48. models = Models({"G": G, "C": C})
  49. hook_kwargs = {"loss_fn": MMDLoss()}
  50. adapter= Aligner(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers, hook_kwargs=hook_kwargs)
  51. elif model_name == DAModels.CORAL:
  52. models = Models({"G": G, "C": C})
  53. hook_kwargs = {"loss_fn": CORALLoss()}
  54. adapter= Aligner(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers, hook_kwargs=hook_kwargs)
  55. elif model_name == DAModels.SOURCE:
  56. models = Models({"G": G, "C": C})
  57. adapter= ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)
  58. return adapter