- import argparse
- import logging
- import os
- from datetime import datetime
- from train import train
- from train_source import train_source
- from utils import HP, DAModels
- import tracemalloc
-
# Configure default logging and silence pytorch-adapt's sub-WARNING output.
logging.basicConfig()
logging.getLogger("pytorch-adapt").setLevel(logging.WARNING)


# All ordered (source, target) domain pairs used for the experiments
# (presumably the Office-31 domains: amazon / webcam / dslr — confirm).
DATASET_PAIRS = [("amazon", "webcam"), ("amazon", "dslr"),
                ("webcam", "amazon"), ("webcam", "dslr"),
                ("dslr", "amazon"), ("dslr", "webcam")]
-
-
def _run_fixed_hp(args, model_name, train_fn, base_output_dir, hp_map):
    """Train every domain pair once with fixed hyperparameters.

    Writes a CSV-style results file under *base_output_dir* and appends the
    tracemalloc current/peak memory counters after each pair's run.
    """
    d = datetime.now()
    results_file = (f"{base_output_dir}/e{args.max_epochs}_p{args.patience}_"
                    f"{d.strftime('%Y%m%d-%H:%M:%S')}.txt")
    with open(results_file, "w") as myfile:
        myfile.write("pair, source_acc, target_acc, best_epoch, best_score, time, cur, peak, lr, gamma\n")

    for source_domain, target_domain in DATASET_PAIRS:
        # e.g. ("amazon", "webcam") -> "a2w", the key used in hp_map.
        pair_name = f"{source_domain[0]}2{target_domain[0]}"

        # NOTE(review): the lookup is hard-coded to CDAN's tuned values for
        # every model; presumably this should be hp_map[model_name.name].
        # Left as-is because hp_map has no entry for SOURCE — confirm intent.
        hp_params = hp_map[DAModels.CDAN.name][pair_name]
        # Explicit CLI --lr / --gamma override the per-pair tuned values.
        lr = args.lr if args.lr else hp_params[0]
        gamma = args.gamma if args.gamma else hp_params[1]
        hp = HP(lr=lr, gamma=gamma)

        # Measure allocations made by this training run; ensure tracemalloc
        # is stopped even if train_fn raises (the original leaked the tracer).
        tracemalloc.start()
        try:
            train_fn(args, model_name, hp, base_output_dir, results_file,
                     source_domain, target_domain)
            cur, peak = tracemalloc.get_traced_memory()
        finally:
            tracemalloc.stop()

        # NOTE(review): this appends "cur, peak" on its own line, while the
        # header places them mid-row — verify against what train_fn writes.
        with open(results_file, "a") as myfile:
            myfile.write(f"{cur}, {peak}\n")


def _run_hp_search(args, model_name, train_fn, base_output_dir):
    """Grid-search (lr, gamma) over all domain pairs and record the best.

    The best hyperparameters are judged by the highest best_score returned
    by train_fn across any single pair.
    """
    # lr -> list of gamma values to try with that lr.
    hp_values = {
        1e-4: [0.99, 0.9],
        1e-5: [0.99, 0.9],
        5*1e-5: [0.99, 0.9],
        5*1e-6: [0.99]
    }

    d = datetime.now()
    hp_file = (f"{base_output_dir}/hp_e{args.max_epochs}_p{args.patience}_"
               f"{d.strftime('%Y%m%d-%H:%M:%S')}.txt")
    with open(hp_file, "w") as myfile:
        myfile.write("lr, gamma, pair, source_acc, target_acc, best_epoch, best_score\n")

    hp_best = None
    hp_best_score = None

    for lr, gamma_list in hp_values.items():
        for gamma in gamma_list:
            hp = HP(lr=lr, gamma=gamma)
            print("HP:", hp)

            for source_domain, target_domain in DATASET_PAIRS:
                _, _, best_score = \
                    train_fn(args, model_name, hp, base_output_dir, hp_file,
                             source_domain, target_domain)

                if best_score is not None and (hp_best_score is None or hp_best_score < best_score):
                    hp_best = hp
                    hp_best_score = best_score

    # Guard against every run returning best_score=None; the original
    # crashed with AttributeError on hp_best.lr in that case.
    with open(hp_file, "a") as myfile:
        if hp_best is not None:
            myfile.write(f"\nbest: {hp_best.lr}, {hp_best.gamma}\n")
        else:
            myfile.write("\nbest: none (no run produced a score)\n")


def run_experiment_on_model(args, model_name, trial_number):
    """Run one trial of the given DA model over all dataset pairs.

    Depending on args.hp_tune, either trains each pair once with fixed
    (per-pair tuned) hyperparameters, or performs a hyperparameter grid
    search. Results are written under
    ``{args.results_root}/{model_name}/{trial_number}``.

    :param args: parsed CLI namespace (see the __main__ block).
    :param model_name: a DAModels enum member selecting the algorithm.
    :param trial_number: integer trial index used in the output path.
    """
    # SOURCE-only training uses its own entry point; all DA models share train().
    train_fn = train_source if model_name == DAModels.SOURCE else train

    base_output_dir = f"{args.results_root}/{model_name}/{trial_number}"
    os.makedirs(base_output_dir, exist_ok=True)

    # Per-model, per-pair tuned (lr, gamma) values; keys are pair codes
    # like "a2d" = amazon -> dslr.
    hp_map = {
        'DANN': {
            'a2d': (5e-05, 0.99),
            'a2w': (5e-05, 0.99),
            'd2a': (0.0001, 0.9),
            'd2w': (5e-05, 0.99),
            'w2a': (0.0001, 0.99),
            'w2d': (0.0001, 0.99),
        },
        'CDAN': {
            'a2d': (1e-05, 1),
            'a2w': (1e-05, 1),
            'd2a': (1e-05, 0.99),
            'd2w': (1e-05, 0.99),
            'w2a': (0.0001, 0.99),
            'w2d': (5e-05, 0.99),
        },
        'MMD': {
            'a2d': (5e-05, 1),
            'a2w': (5e-05, 0.99),
            'd2a': (0.0001, 0.99),
            'd2w': (5e-05, 0.9),
            'w2a': (0.0001, 0.99),
            'w2d': (1e-05, 0.99),
        },
        'MCD': {
            'a2d': (1e-05, 0.9),
            'a2w': (0.0001, 1),
            'd2a': (1e-05, 0.9),
            'd2w': (1e-05, 0.99),
            'w2a': (1e-05, 0.9),
            'w2d': (5*1e-6, 0.99),
        },
        'CORAL': {
            'a2d': (1e-05, 0.99),
            'a2w': (1e-05, 1),
            'd2a': (5*1e-6, 0.99),
            'd2w': (0.0001, 0.99),
            'w2a': (1e-5, 0.99),
            'w2d': (0.0001, 0.99),
        },
    }

    if not args.hp_tune:
        _run_fixed_hp(args, model_name, train_fn, base_output_dir, hp_map)
    else:
        _run_hp_search(args, model_name, train_fn, base_output_dir)
-
-
-
-
- if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument('--max_epochs', default=60, type=int)
- parser.add_argument('--patience', default=10, type=int)
- parser.add_argument('--batch_size', default=32, type=int)
- parser.add_argument('--num_workers', default=2, type=int)
- parser.add_argument('--trials_count', default=3, type=int)
- parser.add_argument('--initial_trial', default=0, type=int)
- parser.add_argument('--download', default=False, type=bool)
- parser.add_argument('--root', default="../")
- parser.add_argument('--data_root', default="../datasets/pytorch-adapt/")
- parser.add_argument('--results_root', default="../results/")
- parser.add_argument('--model_names', default=["DANN"], nargs='+')
- parser.add_argument('--lr', default=None, type=float)
- parser.add_argument('--gamma', default=None, type=float)
- parser.add_argument('--hp_tune', default=False, type=bool)
- parser.add_argument('--source', default=None)
- parser.add_argument('--target', default=None)
- parser.add_argument('--vishook_frequency', default=5, type=int)
- parser.add_argument('--source_checkpoint_base_dir', default=None) # default='../results/DAModels.SOURCE/'
- parser.add_argument('--source_checkpoint_trial_number', default=-1, type=int)
-
- args = parser.parse_args()
-
- print(args)
-
- for trial_number in range(args.initial_trial, args.initial_trial + args.trials_count):
- if args.source_checkpoint_trial_number == -1:
- args.source_checkpoint_trial_number = trial_number
-
- for model_name in args.model_names:
- try:
- model_enum = DAModels(model_name)
- except ValueError:
- logging.warning(f"Model {model_name} not found. skipping...")
- continue
-
- run_experiment_on_model(args, model_enum, trial_number)
|