import os
import random
import gc

import numpy as np
import torch
import learn2learn as l2l

from learnToLearnTest import test
from fast_adapt import fast_adapt

# ---------------------------------------------------------------------------
# NOTE(review): the source of this file was a side-by-side diff render of a
# training script; unchanged context lines (construction of `head`, `config`,
# `total_dataset`, `optimizer`, the per-task inner loop, etc.) were elided by
# the diff, so several names used below are defined in those elided portions.
# The visible statements are reconstructed here in order, taking the newer
# (right-hand) side of the diff where the two columns disagreed.
# ---------------------------------------------------------------------------

# DATA GENERATION
print("DATA GENERATION PHASE")

# Pin CUDA device numbering to PCI bus order so the index below is stable
# across reboots/driver enumerations.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

master_path = "/media/external_10TB/10TB/maheri/melu_data5"
if not os.path.exists("{}/".format(master_path)):
    os.mkdir("{}/".format(master_path))

# META LEARNING
print("META LEARNING PHASE")

# Alternative meta-learner kept from the old side of the diff:
# head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'], first_order=True)

# Wrap the head model in learn2learn's GBML with a learnable linear
# gradient transform (adapt_transform=True updates the transform during
# the inner loop as well); first_order avoids second-derivative cost.
transform = l2l.optim.ModuleTransform(torch.nn.Linear)
head = l2l.algorithms.GBML(head, transform=transform, lr=config['local_lr'],
                           adapt_transform=True, first_order=True)
head.cuda()

# Setup optimization
training_set_size = len(total_dataset)
batch_size = config['batch_size']
torch.cuda.empty_cache()

print("\n\n\n")
for iteration in range(config['num_epoch']):
    # Re-shuffle every epoch so meta-batches differ between epochs
    # (the diff moved these three lines from before the loop to inside it).
    random.shuffle(total_dataset)
    num_batch = int(training_set_size / batch_size)
    a, b, c, d = zip(*total_dataset)

    for i in range(num_batch):
        optimizer.zero_grad()
        meta_train_error = 0.0

        # (elided in the diff: per-task fast adaptation accumulating
        # gradients into the meta-parameters and into meta_train_error)

        # Average the accumulated gradients over the meta-batch.
        # NOTE(review): the source used `batch_sz`, which is not defined in
        # the visible lines — `batch_size` (defined above) looks intended;
        # confirm against the elided context before relying on this.
        p.grad.data.mul_(1.0 / batch_size)
        optimizer.step()

        # Free per-batch tensors aggressively to keep GPU memory bounded.
        torch.cuda.empty_cache()
        del supp_xs, supp_ys, query_xs, query_ys
        gc.collect()

print("===============================================\n")