
train_source.py 4.4KB

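"""Source-only training on Office-31 with pytorch-adapt.

Trains a pretrained Office-31 backbone (G) and classifier head (C) on a single
source domain, checkpoints the best model by source validation accuracy, then
evaluates that checkpoint on the source and target validation sets and appends
the scores to a results file.
"""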
import copy
import gc
import os
from datetime import datetime

import matplotlib.pyplot as plt
import torch

from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner, SymNets
from pytorch_adapt.containers import Models, Optimizers, LRSchedulers, Misc
from pytorch_adapt.datasets import DataloaderCreator, get_office31
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite
from pytorch_adapt.layers import RandomizedDotProduct, MultipleModels, CORALLoss, MMDLoss
from pytorch_adapt.models import Discriminator, office31C, office31G
from pytorch_adapt.utils import common_functions
from pytorch_adapt.validators import (
    AccuracyValidator,
    DiversityValidator,
    EntropyValidator,
    IMValidator,
    MultipleValidators,
    ScoreHistory,
)

from classifier_adapter import ClassifierAdapter
from utils import HP, DAModels

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_source(args, model_name, hp, base_output_dir, results_file, source_domain, target_domain):
    # Skip this pair if the run was restricted to a specific source/target domain.
    if args.source is not None and args.source != source_domain:
        return None, None, None
    if args.target is not None and args.target != target_domain:
        return None, None, None

    pair_name = f"{source_domain[0]}2{target_domain[0]}"
    output_dir = os.path.join(base_output_dir, pair_name)
    os.makedirs(output_dir, exist_ok=True)
    print("output dir:", output_dir)
    # Source-only training data: load just the source domain of Office-31.
    datasets = get_office31([source_domain], [],
                            folder=args.data_root,
                            return_target_with_labels=True,
                            download=args.download)
    dc = DataloaderCreator(batch_size=args.batch_size,
                           num_workers=args.num_workers)

    # Pretrained Office-31 feature extractor (G) and classifier head (C).
    weights_root = os.path.join(args.data_root, "weights")
    G = office31G(pretrained=True, model_dir=weights_root).to(device)
    C = office31C(domain=source_domain, pretrained=True,
                  model_dir=weights_root).to(device)
    optimizers = Optimizers((torch.optim.Adam, {"lr": hp.lr}))
    lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": hp.gamma}))

    models = Models({"G": G, "C": C})
    adapter = ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)
    # adapter = get_model(model_name, hp, args.data_root, source_domain)

    print("checkpoint dir:", output_dir)
    checkpoint_fn = CheckpointFnCreator(dirname=f"{output_dir}/saved_models", require_empty=False)

    # Track source validation accuracy per epoch; Ignite checkpoints the best model.
    sourceAccuracyValidator = AccuracyValidator()
    val_hooks = [ScoreHistory(sourceAccuracyValidator)]
    trainer = Ignite(
        adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn
    )
    early_stopper_kwargs = {"patience": args.patience}

    start_time = datetime.now()
    best_score, best_epoch = trainer.run(
        datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs
    )
    end_time = datetime.now()
    training_time = end_time - start_time

    print(f"best_score={best_score}, best_epoch={best_epoch}")

    # Plot the per-epoch source validation accuracy.
    plt.plot(val_hooks[0].score_history, label='source')
    plt.title("val accuracy")
    plt.legend()
    plt.savefig(f"{output_dir}/val_accuracy.png")
    plt.close('all')
    # Reload the datasets with the labeled target domain included, then evaluate
    # the best checkpoint on both source and target validation sets.
    datasets = get_office31([source_domain], [target_domain], folder=args.data_root, return_target_with_labels=True)
    dc = DataloaderCreator(batch_size=args.batch_size, num_workers=args.num_workers, all_val=True)

    validator = AccuracyValidator(key_map={"src_val": "src_val"})
    src_score = trainer.evaluate_best_model(datasets, validator, dc)
    print("Source acc:", src_score)

    validator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"})
    target_score = trainer.evaluate_best_model(datasets, validator, dc)
    print("Target acc:", target_score)
    print("---------")
    # Append results as a CSV row; hyperparameter-tuning runs also record lr and gamma.
    if args.hp_tune:
        with open(results_file, "a") as myfile:
            myfile.write(f"{hp.lr}, {hp.gamma}, {pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}\n")
    else:
        with open(results_file, "a") as myfile:
            myfile.write(
                f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}, {training_time.seconds}, ")

    # Free GPU memory before the next domain pair.
    del adapter
    gc.collect()
    torch.cuda.empty_cache()

    return src_score, target_score, best_score
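
For reference, a minimal driver sketch (not part of this file) showing how train_source might be called. The argument fields and hyperparameter values below are assumptions standing in for the repo's own utils.HP and CLI parsing; SimpleNamespace is used only because the function accesses plain attributes.

# Hypothetical usage sketch -- attribute names mirror what train_source reads above.
import types
from train_source import train_source

args = types.SimpleNamespace(
    source=None, target=None,          # or pin a specific domain pair
    data_root="./datasets", download=True,
    batch_size=32, num_workers=2,
    patience=5, max_epochs=60, hp_tune=False,
)
hp = types.SimpleNamespace(lr=1e-4, gamma=0.99)  # stand-in for utils.HP

src_acc, tgt_acc, best_val = train_source(
    args, "source_only", hp,
    base_output_dir="results", results_file="results/source_only.csv",
    source_domain="amazon", target_domain="webcam",
)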