You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

source.py 5.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import logging
  2. import matplotlib.pyplot as plt
  3. import pandas as pd
  4. import seaborn as sns
  5. import torch
  6. import umap
  7. from tqdm import tqdm
  8. import os
  9. import gc
  10. from datetime import datetime
  11. from pytorch_adapt.adapters import DANN, MCD, VADA, CDAN, RTN, ADDA, Aligner
  12. from pytorch_adapt.adapters.base_adapter import BaseGCAdapter
  13. from pytorch_adapt.adapters.utils import with_opt
  14. from pytorch_adapt.hooks import ClassifierHook
  15. from pytorch_adapt.containers import Models, Optimizers, LRSchedulers
  16. from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm, get_office31
  17. from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite
  18. from pytorch_adapt.models import Discriminator, mnistC, mnistG, office31C, office31G
  19. from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory
  20. from pytorch_adapt.containers import Misc
  21. from pytorch_adapt.layers import RandomizedDotProduct
  22. from pytorch_adapt.layers import MultipleModels
  23. from pytorch_adapt.utils import common_functions
  24. from pytorch_adapt.containers import LRSchedulers
  25. import copy
  26. from pprint import pprint
# --- Training hyper-parameters ---------------------------------------------
PATIENCE = 5      # early-stopping patience (epochs without validation improvement)
EPOCHS = 50       # maximum training epochs per run
BATCH_SIZE = 32   # training batch size
NUM_WORKERS = 2   # dataloader worker processes
TRIAL_COUNT = 5   # number of repeated trials of the whole experiment

# Silence pytorch-adapt's info-level logging; keep warnings visible.
logging.basicConfig()
logging.getLogger("pytorch-adapt").setLevel(logging.WARNING)
class ClassifierAdapter(BaseGCAdapter):
    """
    Source-only adapter: trains G (generator/feature extractor) and
    C (classifier) with a plain classification hook and no
    domain-adaptation loss.

    Wraps [ClassifierHook][pytorch_adapt.hooks.ClassifierHook].

    |Container|Required keys|
    |---|---|
    |models|```["G", "C"]```|
    |optimizers|```["G", "C"]```|
    """

    def init_hook(self, hook_kwargs):
        # Wrap every registered optimizer so the hook steps them itself,
        # then build the training hook with any extra keyword arguments.
        opts = with_opt(list(self.optimizers.keys()))
        self.hook = self.hook_cls(opts, **hook_kwargs)

    @property
    def hook_cls(self):
        # Hook class instantiated by init_hook.
        return ClassifierHook
  48. root='/content/drive/MyDrive/Shared with Sabas/Bsc/'
  49. # root="datasets/pytorch-adapt/"
  50. data_root = os.path.join(root,'data')
  51. batch_size=BATCH_SIZE
  52. num_workers=NUM_WORKERS
  53. device = torch.device("cuda")
  54. model_dir = os.path.join(data_root, "weights")
  55. DATASET_PAIRS = [("amazon", ["webcam", "dslr"]),
  56. ("webcam", ["dslr", "amazon"]),
  57. ("dslr", ["amazon", "webcam"])
  58. ]
  59. MODEL_NAME = "base"
  60. model_name = MODEL_NAME
  61. pass_next= 0
  62. pass_trial = 0
# Main experiment loop: for each trial, train a source-only ("base") model on
# every Office-31 source domain, then evaluate the best checkpoint on that
# source's target domains and append the scores to a per-trial results file.
# NOTE(review): the source arrived with indentation stripped; the nesting
# below is reconstructed from the control flow — confirm against the original.
for trial_number in range(10, 10 + TRIAL_COUNT):
    # pass_trial > 0 skips whole trials (manual resume mechanism).
    if pass_trial:
        pass_trial -= 1
        continue
    base_output_dir = f"{root}/results/vishook/{MODEL_NAME}/{trial_number}"
    os.makedirs(base_output_dir, exist_ok=True)
    d = datetime.now()
    # One timestamped, CSV-style results file per trial.
    results_file = f"{base_output_dir}/{d.strftime('%Y%m%d-%H:%M:%S')}.txt"
    with open(results_file, "w") as myfile:
        myfile.write("pair, source_acc, target_acc, best_epoch, time\n")
    for source_domain, target_domains in DATASET_PAIRS:
        # Source-only training: the target-domain list is empty here.
        datasets = get_office31([source_domain], [], folder=data_root)
        dc = DataloaderCreator(batch_size=batch_size,
                               num_workers=num_workers,
                               )
        # NOTE(review): `dataloaders` appears unused — trainer.run builds its
        # own loaders via dataloader_creator=dc below; verify before removing.
        dataloaders = dc(**datasets)
        # Pretrained Office-31 feature extractor (G) and per-domain classifier (C).
        G = office31G(pretrained=True, model_dir=model_dir)
        C = office31C(domain=source_domain, pretrained=True, model_dir=model_dir)
        optimizers = Optimizers((torch.optim.Adam, {"lr": 0.0005}))
        lr_schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99}))
        if model_name == "base":
            models = Models({"G": G, "C": C})
            adapter = ClassifierAdapter(models=models, optimizers=optimizers, lr_schedulers=lr_schedulers)
        checkpoint_fn = CheckpointFnCreator(dirname="saved_models", require_empty=False)
        # Validation metric: accuracy on the source validation split,
        # recorded per epoch so it can be plotted below.
        val_hooks = [ScoreHistory(AccuracyValidator())]
        trainer = Ignite(
            adapter, val_hooks=val_hooks, checkpoint_fn=checkpoint_fn,
        )
        early_stopper_kwargs = {"patience": PATIENCE}
        start_time = datetime.now()
        best_score, best_epoch = trainer.run(
            datasets, dataloader_creator=dc, max_epochs=EPOCHS, early_stopper_kwargs=early_stopper_kwargs
        )
        end_time = datetime.now()
        training_time = end_time - start_time
        for target_domain in target_domains:
            # pass_next > 0 skips individual source->target evaluations (resume aid).
            if pass_next:
                pass_next -= 1
                continue
            # e.g. "a2w" for amazon -> webcam.
            pair_name = f"{source_domain[0]}2{target_domain[0]}"
            output_dir = os.path.join(base_output_dir, pair_name)
            os.makedirs(output_dir, exist_ok=True)
            print("output dir:", output_dir)
            # print(f"best_score={best_score}, best_epoch={best_epoch}, training_time={training_time.seconds}")
            # Save the source-validation accuracy curve of this training run.
            plt.plot(val_hooks[0].score_history, label='source')
            plt.title("val accuracy")
            plt.legend()
            plt.savefig(f"{output_dir}/val_accuracy.png")
            plt.close('all')
            # Reload the data with the labelled target split for evaluation.
            datasets = get_office31([source_domain], [target_domain], folder=data_root, return_target_with_labels=True)
            dc = DataloaderCreator(batch_size=64, num_workers=2, all_val=True)
            validator = AccuracyValidator(key_map={"src_val": "src_val"})
            src_score = trainer.evaluate_best_model(datasets, validator, dc)
            print("Source acc:", src_score)
            # Score the labelled target split with the same accuracy validator
            # by mapping it onto the validator's "src_val" input key.
            validator = AccuracyValidator(key_map={"target_val_with_labels": "src_val"})
            target_score = trainer.evaluate_best_model(datasets, validator, dc)
            print("Target acc:", target_score)
            with open(results_file, "a") as myfile:
                myfile.write(f"{pair_name}, {src_score}, {target_score}, {best_epoch}, {training_time.seconds}\n")
        # Release the trainer and models before the next source domain to
        # cap GPU memory use.
        del trainer
        del G
        del C
        gc.collect()
        torch.cuda.empty_cache()