|
|
@@ -18,12 +18,13 @@ from learnToLearnTest import test |
|
|
|
from fast_adapt import fast_adapt |
|
|
|
import gc |
|
|
|
|
|
|
|
if config['use_cuda']: |
|
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
|
|
master_path= "/media/external_10TB/10TB/maheri/melu_data5" |
|
|
|
|
|
|
|
# DATA GENERATION |
|
|
|
print("DATA GENERATION PHASE") |
|
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
|
|
master_path= "/media/external_10TB/10TB/maheri/melu_data5" |
|
|
|
if not os.path.exists("{}/".format(master_path)): |
|
|
|
os.mkdir("{}/".format(master_path)) |
|
|
|
# preparing dataset. It needs about 22GB of your hard disk space. |
|
|
@@ -37,21 +38,24 @@ fc2_in_dim = config['first_fc_hidden_dim'] |
|
|
|
fc2_out_dim = config['second_fc_hidden_dim'] |
|
|
|
use_cuda = config['use_cuda'] |
|
|
|
|
|
|
|
emb = EmbeddingModule(config).cuda() |
|
|
|
|
|
|
|
fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim) |
|
|
|
fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim) |
|
|
|
linear_out = torch.nn.Linear(fc2_out_dim, 1) |
|
|
|
head = torch.nn.Sequential(fc1,fc2,linear_out) |
|
|
|
|
|
|
|
if use_cuda: |
|
|
|
emb = EmbeddingModule(config).cuda() |
|
|
|
else: |
|
|
|
emb = EmbeddingModule(config) |
|
|
|
|
|
|
|
# META LEARNING |
|
|
|
print("META LEARNING PHASE") |
|
|
|
# head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True) |
|
|
|
transform = l2l.optim.ModuleTransform(torch.nn.Linear) |
|
|
|
head = l2l.algorithms.GBML(head , transform=transform , lr=config['local_lr'] , adapt_transform=True,first_order=True) |
|
|
|
# head.to(torch.device('cuda:0')) |
|
|
|
head.cuda() |
|
|
|
|
|
|
|
if use_cuda: |
|
|
|
head.cuda() |
|
|
|
|
|
|
|
# Setup optimization |
|
|
|
print("SETUP OPTIMIZATION PHASE") |
|
|
@@ -75,8 +79,7 @@ total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)) |
|
|
|
del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s) |
|
|
|
training_set_size = len(total_dataset) |
|
|
|
batch_size = config['batch_size'] |
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
# torch.cuda.empty_cache() |
|
|
|
|
|
|
|
random.shuffle(total_dataset) |
|
|
|
num_batch = int(training_set_size / batch_size) |
|
|
@@ -100,15 +103,15 @@ for iteration in range(config['num_epoch']): |
|
|
|
query_ys = list(d[batch_size * i:batch_size * (i + 1)]) |
|
|
|
batch_sz = len(supp_xs) |
|
|
|
|
|
|
|
for j in range(batch_size): |
|
|
|
supp_xs[j] = supp_xs[j].cuda() |
|
|
|
supp_ys[j] = supp_ys[j].cuda() |
|
|
|
query_xs[j] = query_xs[j].cuda() |
|
|
|
query_ys[j] = query_ys[j].cuda() |
|
|
|
if use_cuda: |
|
|
|
for j in range(batch_size): |
|
|
|
supp_xs[j] = supp_xs[j].cuda() |
|
|
|
supp_ys[j] = supp_ys[j].cuda() |
|
|
|
query_xs[j] = query_xs[j].cuda() |
|
|
|
query_ys[j] = query_ys[j].cuda() |
|
|
|
|
|
|
|
for task in range(batch_sz): |
|
|
|
# print("EPOCH: ", iteration," BATCH: ",i, "TASK: ",task) |
|
|
|
|
|
|
|
# Compute meta-training loss |
|
|
|
learner = head.clone() |
|
|
|
temp_sxs = emb(supp_xs[task]) |
|
|
@@ -140,7 +143,7 @@ for iteration in range(config['num_epoch']): |
|
|
|
p.grad.data.mul_(1.0 / batch_sz) |
|
|
|
optimizer.step() |
|
|
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
# torch.cuda.empty_cache() |
|
|
|
del(supp_xs,supp_ys,query_xs,query_ys) |
|
|
|
gc.collect() |
|
|
|
|
|
|
@@ -149,8 +152,7 @@ for iteration in range(config['num_epoch']): |
|
|
|
|
|
|
|
# save model |
|
|
|
final_model = torch.nn.Sequential(emb,head) |
|
|
|
torch.save(final_model.state_dict(), master_path + "/models_sgd.pkl") |
|
|
|
|
|
|
|
torch.save(final_model.state_dict(), master_path + "/models_gbml.pkl") |
|
|
|
|
|
|
|
# testing |
|
|
|
print("start of test phase") |