|
|
@@ -16,14 +16,13 @@ import random |
|
|
|
import numpy as np |
|
|
|
from learnToLearnTest import test |
|
|
|
from fast_adapt import fast_adapt |
|
|
|
|
|
|
|
|
|
|
|
import gc |
|
|
|
|
|
|
|
|
|
|
|
# DATA GENERATION |
|
|
|
print("DATA GENERATION PHASE") |
|
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1" |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
|
|
master_path= "/media/external_10TB/10TB/maheri/melu_data5" |
|
|
|
if not os.path.exists("{}/".format(master_path)): |
|
|
|
os.mkdir("{}/".format(master_path)) |
|
|
@@ -48,7 +47,10 @@ head = torch.nn.Sequential(fc1,fc2,linear_out) |
|
|
|
|
|
|
|
# META LEARNING |
|
|
|
print("META LEARNING PHASE") |
|
|
|
head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True) |
|
|
|
# head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True) |
|
|
|
transform = l2l.optim.ModuleTransform(torch.nn.Linear) |
|
|
|
head = l2l.algorithms.GBML(head , transform=transform , lr=config['local_lr'] , adapt_transform=True,first_order=True) |
|
|
|
# head.to(torch.device('cuda:0')) |
|
|
|
head.cuda() |
|
|
|
|
|
|
|
# Setup optimization |
|
|
@@ -74,12 +76,14 @@ del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) |
|
|
|
training_set_size = len(total_dataset) |
|
|
|
batch_size = config['batch_size'] |
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
random.shuffle(total_dataset) |
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
a, b, c, d = zip(*total_dataset) |
|
|
|
|
|
|
|
print("\n\n\n") |
|
|
|
for iteration in range(config['num_epoch']): |
|
|
|
random.shuffle(total_dataset) |
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
|
a, b, c, d = zip(*total_dataset) |
|
|
|
|
|
|
|
for i in range(num_batch): |
|
|
|
optimizer.zero_grad() |
|
|
|
meta_train_error = 0.0 |
|
|
@@ -136,6 +140,10 @@ for iteration in range(config['num_epoch']): |
|
|
|
p.grad.data.mul_(1.0 / batch_sz) |
|
|
|
optimizer.step() |
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
del(supp_xs,supp_ys,query_xs,query_ys) |
|
|
|
gc.collect() |
|
|
|
|
|
|
|
print("===============================================\n") |
|
|
|
|
|
|
|
|