| @@ -16,14 +16,13 @@ import random | |||
| import numpy as np | |||
| from learnToLearnTest import test | |||
| from fast_adapt import fast_adapt | |||
| import gc | |||
| # DATA GENERATION | |||
| print("DATA GENERATION PHASE") | |||
| os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |||
| os.environ["CUDA_VISIBLE_DEVICES"] = "1" | |||
| os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |||
| master_path= "/media/external_10TB/10TB/maheri/melu_data5" | |||
| if not os.path.exists("{}/".format(master_path)): | |||
| os.mkdir("{}/".format(master_path)) | |||
| @@ -48,7 +47,10 @@ head = torch.nn.Sequential(fc1,fc2,linear_out) | |||
| # META LEARNING | |||
| print("META LEARNING PHASE") | |||
| head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True) | |||
| # head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True) | |||
| transform = l2l.optim.ModuleTransform(torch.nn.Linear) | |||
| head = l2l.algorithms.GBML(head , transform=transform , lr=config['local_lr'] , adapt_transform=True,first_order=True) | |||
| # head.to(torch.device('cuda:0')) | |||
| head.cuda() | |||
| # Setup optimization | |||
| @@ -74,12 +76,14 @@ del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) | |||
| training_set_size = len(total_dataset) | |||
| batch_size = config['batch_size'] | |||
| torch.cuda.empty_cache() | |||
| random.shuffle(total_dataset) | |||
| num_batch = int(training_set_size / batch_size) | |||
| a, b, c, d = zip(*total_dataset) | |||
| print("\n\n\n") | |||
| for iteration in range(config['num_epoch']): | |||
| random.shuffle(total_dataset) | |||
| num_batch = int(training_set_size / batch_size) | |||
| a, b, c, d = zip(*total_dataset) | |||
| for i in range(num_batch): | |||
| optimizer.zero_grad() | |||
| meta_train_error = 0.0 | |||
| @@ -136,6 +140,10 @@ for iteration in range(config['num_epoch']): | |||
| p.grad.data.mul_(1.0 / batch_sz) | |||
| optimizer.step() | |||
| torch.cuda.empty_cache() | |||
| del(supp_xs,supp_ys,query_xs,query_ys) | |||
| gc.collect() | |||
| print("===============================================\n") | |||