
train.py 4.1KB

import matplotlib.pyplot as plt
import torch
import os
import gc
from datetime import datetime

from pytorch_adapt.datasets import DataloaderCreator, get_office31
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite
from pytorch_adapt.validators import (
    AccuracyValidator,
    IMValidator,
    ScoreHistory,
    DiversityValidator,
    EntropyValidator,
    MultipleValidators,
)

from models import get_model
from utils import DAModels
from vis_hook import VizHook


def train(args, model_name, hp, base_output_dir, results_file, source_domain, target_domain):
    # Skip pairs that do not match the requested source/target domains.
    if args.source is not None and args.source != source_domain:
        return None, None, None
    if args.target is not None and args.target != target_domain:
        return None, None, None

    pair_name = f"{source_domain[0]}2{target_domain[0]}"
    output_dir = os.path.join(base_output_dir, pair_name)
    os.makedirs(output_dir, exist_ok=True)
    print("output dir:", output_dir)

    # Office-31 datasets and dataloaders for this source/target pair.
    datasets = get_office31([source_domain], [target_domain],
                            folder=args.data_root,
                            return_target_with_labels=True,
                            download=args.download)
    dc = DataloaderCreator(batch_size=args.batch_size,
                           num_workers=args.num_workers,
                           train_names=["train"],
                           val_names=["src_train", "target_train", "src_val", "target_val",
                                      "target_train_with_labels", "target_val_with_labels"])

    source_checkpoint_dir = None if not args.source_checkpoint_base_dir else \
        f"{args.source_checkpoint_base_dir}/{args.source_checkpoint_trial_number}/{pair_name}/saved_models"
    print("source_checkpoint_dir", source_checkpoint_dir)

    adapter = get_model(model_name, hp, source_checkpoint_dir, args.data_root, source_domain)

    checkpoint_fn = CheckpointFnCreator(dirname=f"{output_dir}/saved_models", require_empty=False)

    # Unsupervised model selection: combine entropy and diversity scores on target predictions.
    sourceAccuracyValidator = AccuracyValidator()
    validators = {
        "entropy": EntropyValidator(),
        "diversity": DiversityValidator(),
    }
    validator = ScoreHistory(MultipleValidators(validators))

    # Target accuracy is tracked for reporting only; it is not used for checkpoint selection.
    targetAccuracyValidator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"})
    val_hooks = [ScoreHistory(sourceAccuracyValidator),
                 ScoreHistory(targetAccuracyValidator),
                 VizHook(output_dir=output_dir, frequency=args.vishook_frequency)]

    trainer = Ignite(
        adapter, validator=validator, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn
    )

    early_stopper_kwargs = {"patience": args.patience}

    start_time = datetime.now()
    best_score, best_epoch = trainer.run(
        datasets, dataloader_creator=dc, max_epochs=args.max_epochs, early_stopper_kwargs=early_stopper_kwargs
    )
    end_time = datetime.now()
    training_time = end_time - start_time

    print(f"best_score={best_score}, best_epoch={best_epoch}")

    # Plot source/target validation accuracy and the selection-score history.
    plt.plot(val_hooks[0].score_history, label='source')
    plt.plot(val_hooks[1].score_history, label='target')
    plt.title("val accuracy")
    plt.legend()
    plt.savefig(f"{output_dir}/val_accuracy.png")
    plt.close('all')

    plt.plot(validator.score_history)
    plt.title("score_history")
    plt.savefig(f"{output_dir}/score_history.png")
    plt.close('all')

    # Evaluate the best checkpoint on the source and target validation sets.
    src_score = trainer.evaluate_best_model(datasets, sourceAccuracyValidator, dc)
    print("Source acc:", src_score)
    target_score = trainer.evaluate_best_model(datasets, targetAccuracyValidator, dc)
    print("Target acc:", target_score)
    print("---------")

    # Append results; the CSV layout differs between hyperparameter-tuning and final runs.
    if args.hp_tune:
        with open(results_file, "a") as myfile:
            myfile.write(f"{hp.lr}, {hp.gamma}, {pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}\n")
    else:
        with open(results_file, "a") as myfile:
            myfile.write(
                f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {best_score}, {training_time.seconds}, {hp.lr}, {hp.gamma},")

    # Free GPU memory before the next domain pair is trained.
    del adapter
    gc.collect()
    torch.cuda.empty_cache()

    return src_score, target_score, best_score
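
For context, here is a minimal driver sketch showing how this train function might be invoked over Office-31 domain pairs. The attribute names on args and hp mirror what the function reads above, but every concrete value, the results-file path, and the model name passed through to get_model are assumptions for illustration, not part of this file.

    # Hypothetical driver (not part of train.py); all values below are assumptions.
    from types import SimpleNamespace
    from itertools import permutations

    from train import train

    args = SimpleNamespace(
        source=None, target=None,            # None = run every source/target pair
        data_root="datasets",                # assumed dataset location
        download=True,
        batch_size=32, num_workers=2,
        source_checkpoint_base_dir=None, source_checkpoint_trial_number=0,
        vishook_frequency=5,
        patience=10, max_epochs=60,
        hp_tune=False,
    )
    hp = SimpleNamespace(lr=1e-4, gamma=0.99)   # assumed hyperparameters

    domains = ["amazon", "dslr", "webcam"]      # the three Office-31 domains
    for source_domain, target_domain in permutations(domains, 2):
        # model_name should be one of the project's DAModels entries; "DANN" is a placeholder here.
        train(args, "DANN", hp,
              "results/trial_0", "results/trial_0/results.csv",
              source_domain, target_domain)

Because train returns (None, None, None) for pairs that do not match args.source/args.target, the same loop works both for a full sweep and for a single requested pair.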