
main.py (6.1 KB)

import argparse
import logging
import os
import tracemalloc
from datetime import datetime

from train import train
from train_source import train_source
from utils import HP, DAModels

logging.basicConfig()
logging.getLogger("pytorch-adapt").setLevel(logging.WARNING)
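
# Every ordered (source, target) pair of the three Office-31 domains.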
DATASET_PAIRS = [("amazon", "webcam"), ("amazon", "dslr"),
                 ("webcam", "amazon"), ("webcam", "dslr"),
                 ("dslr", "amazon"), ("dslr", "webcam")]
def run_experiment_on_model(args, model_name, trial_number):
    # The SOURCE model has its own training entry point.
    train_fn = train
    if model_name == DAModels.SOURCE:
        train_fn = train_source

    base_output_dir = f"{args.results_root}/{model_name}/{trial_number}"
    os.makedirs(base_output_dir, exist_ok=True)
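
    # Per-pair (lr, gamma) hyperparameters, keyed by model name and then by
    # '<source initial>2<target initial>' (e.g. 'a2d' = amazon -> dslr).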
    hp_map = {
        'DANN': {
            'a2d': (5e-05, 0.99),
            'a2w': (5e-05, 0.99),
            'd2a': (1e-04, 0.9),
            'd2w': (5e-05, 0.99),
            'w2a': (1e-04, 0.99),
            'w2d': (1e-04, 0.99),
        },
        'CDAN': {
            'a2d': (1e-05, 1),
            'a2w': (1e-05, 1),
            'd2a': (1e-05, 0.99),
            'd2w': (1e-05, 0.99),
            'w2a': (1e-04, 0.99),
            'w2d': (5e-05, 0.99),
        },
        'MMD': {
            'a2d': (5e-05, 1),
            'a2w': (5e-05, 0.99),
            'd2a': (1e-04, 0.99),
            'd2w': (5e-05, 0.9),
            'w2a': (1e-04, 0.99),
            'w2d': (1e-05, 0.99),
        },
        'MCD': {
            'a2d': (1e-05, 0.9),
            'a2w': (1e-04, 1),
            'd2a': (1e-05, 0.9),
            'd2w': (1e-05, 0.99),
            'w2a': (1e-05, 0.9),
            'w2d': (5e-06, 0.99),
        },
        'CORAL': {
            'a2d': (1e-05, 0.99),
            'a2w': (1e-05, 1),
            'd2a': (5e-06, 0.99),
            'd2w': (1e-04, 0.99),
            'w2a': (1e-05, 0.99),
            'w2d': (1e-04, 0.99),
        },
    }
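
    # Fixed-HP mode: train every domain pair once and append accuracies and
    # memory statistics to a timestamped results file.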
    if not args.hp_tune:
        d = datetime.now()
        results_file = f"{base_output_dir}/e{args.max_epochs}_p{args.patience}_{d.strftime('%Y%m%d-%H:%M:%S')}.txt"
        with open(results_file, "w") as myfile:
            myfile.write("pair, source_acc, target_acc, best_epoch, best_score, time, cur, peak, lr, gamma\n")

        for source_domain, target_domain in DATASET_PAIRS:
            pair_name = f"{source_domain[0]}2{target_domain[0]}"
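            # Hyperparameters are always looked up in the CDAN table here,
            # regardless of `model_name`; --lr/--gamma override them.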
            hp_params = hp_map[DAModels.CDAN.name][pair_name]
            lr = args.lr if args.lr else hp_params[0]
            gamma = args.gamma if args.gamma else hp_params[1]
            hp = HP(lr=lr, gamma=gamma)

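            # Trace Python heap usage over this training run; `cur`/`peak`
            # are the current and peak allocated bytes.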
            tracemalloc.start()
            train_fn(args, model_name, hp, base_output_dir, results_file,
                     source_domain, target_domain)
            cur, peak = tracemalloc.get_traced_memory()
            tracemalloc.stop()

            with open(results_file, "a") as myfile:
                myfile.write(f"{cur}, {peak}\n")

    else:
        # gamma_list = [1, 0.99, 0.9, 0.8]
        # lr_list = [1e-5, 3*1e-5, 1e-4, 3*1e-4]
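        # Grid of lr -> candidate gamma values searched over all domain pairs.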
        hp_values = {
            1e-04: [0.99, 0.9],
            1e-05: [0.99, 0.9],
            5e-05: [0.99, 0.9],
            5e-06: [0.99],
            # 5e-04: [1, 0.99],
            # 1e-03: [0.99, 0.9, 0.8],
        }

        d = datetime.now()
        hp_file = f"{base_output_dir}/hp_e{args.max_epochs}_p{args.patience}_{d.strftime('%Y%m%d-%H:%M:%S')}.txt"
        with open(hp_file, "w") as myfile:
            myfile.write("lr, gamma, pair, source_acc, target_acc, best_epoch, best_score\n")

        hp_best = None
        hp_best_score = None

        for lr, gamma_list in hp_values.items():
            for gamma in gamma_list:
                hp = HP(lr=lr, gamma=gamma)
                print("HP:", hp)

                for source_domain, target_domain in DATASET_PAIRS:
                    _, _, best_score = train_fn(
                        args, model_name, hp, base_output_dir, hp_file,
                        source_domain, target_domain)

                    # Keep the (lr, gamma) pair with the highest best_score.
                    if best_score is not None and (hp_best_score is None or hp_best_score < best_score):
                        hp_best = hp
                        hp_best_score = best_score

        # Guard against the case where no run returned a score.
        if hp_best is not None:
            with open(hp_file, "a") as myfile:
                myfile.write(f"\nbest: {hp_best.lr}, {hp_best.gamma}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_epochs', default=60, type=int)
    parser.add_argument('--patience', default=10, type=int)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--num_workers', default=2, type=int)
    parser.add_argument('--trials_count', default=3, type=int)
    parser.add_argument('--initial_trial', default=0, type=int)
    # Boolean flags use store_true: with type=bool, any non-empty string
    # (including "False") would parse as True.
    parser.add_argument('--download', action='store_true')
    parser.add_argument('--root', default="../")
    parser.add_argument('--data_root', default="../datasets/pytorch-adapt/")
    parser.add_argument('--results_root', default="../results/")
    parser.add_argument('--model_names', default=["DANN"], nargs='+')
    parser.add_argument('--lr', default=None, type=float)
    parser.add_argument('--gamma', default=None, type=float)
    parser.add_argument('--hp_tune', action='store_true')
    parser.add_argument('--source', default=None)
    parser.add_argument('--target', default=None)
    parser.add_argument('--vishook_frequency', default=5, type=int)
    parser.add_argument('--source_checkpoint_base_dir', default=None)  # e.g. '../results/DAModels.SOURCE/'
    parser.add_argument('--source_checkpoint_trial_number', default=-1, type=int)

    args = parser.parse_args()
    print(args)
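
    # Run `trials_count` trials, numbered from `initial_trial`.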
    for trial_number in range(args.initial_trial, args.initial_trial + args.trials_count):
        # -1 means "use the current trial's source checkpoint"; note that the
        # value set on the first trial is then reused for all later trials.
        if args.source_checkpoint_trial_number == -1:
            args.source_checkpoint_trial_number = trial_number

        for model_name in args.model_names:
            try:
                model_enum = DAModels(model_name)
            except ValueError:
                logging.warning(f"Model {model_name} not found; skipping...")
                continue

            run_experiment_on_model(args, model_enum, trial_number)
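
# Illustrative invocations (flags as defined above; paths are examples only):
#   python main.py --model_names DANN CDAN --max_epochs 60 --patience 10
#   python main.py --model_names MMD --hp_tune                  # grid-search lr/gamma
#   python main.py --model_names DANN --lr 1e-4 --gamma 0.99    # override tuned HPs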