
GBML + some optimization for cuda memory

master
mohamad maheri, 3 years ago
commit a4e913f57a
1 changed file with 16 additions and 8 deletions
learnToLearn.py  +16 -8


@@ -16,14 +16,13 @@ import random
 import numpy as np
 from learnToLearnTest import test
 from fast_adapt import fast_adapt

+import gc

 # DATA GENERATION
 print("DATA GENERATION PHASE")
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 master_path= "/media/external_10TB/10TB/maheri/melu_data5"
 if not os.path.exists("{}/".format(master_path)):
     os.mkdir("{}/".format(master_path))
@@ -48,7 +47,10 @@ head = torch.nn.Sequential(fc1,fc2,linear_out)

 # META LEARNING
 print("META LEARNING PHASE")
-head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'], first_order=True)
+# head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'], first_order=True)
+transform = l2l.optim.ModuleTransform(torch.nn.Linear)
+head = l2l.algorithms.GBML(head, transform=transform, lr=config['local_lr'], adapt_transform=True, first_order=True)
+# head.to(torch.device('cuda:0'))
 head.cuda()

 # Setup optimization
@@ -74,12 +76,14 @@ del(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s)
 training_set_size = len(total_dataset)
 batch_size = config['batch_size']

+torch.cuda.empty_cache()
-random.shuffle(total_dataset)
-num_batch = int(training_set_size / batch_size)
-a, b, c, d = zip(*total_dataset)

 print("\n\n\n")
 for iteration in range(config['num_epoch']):
+    random.shuffle(total_dataset)
+    num_batch = int(training_set_size / batch_size)
+    a, b, c, d = zip(*total_dataset)
+
     for i in range(num_batch):
         optimizer.zero_grad()
         meta_train_error = 0.0
@@ -136,6 +140,10 @@ for iteration in range(config['num_epoch']):
p.grad.data.mul_(1.0 / batch_sz)
optimizer.step()

torch.cuda.empty_cache()
del(supp_xs,supp_ys,query_xs,query_ys)
gc.collect()

print("===============================================\n")

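The core change in the second hunk swaps learn2learn's MetaSGD wrapper for GBML with a learnable gradient transform: ModuleTransform(torch.nn.Linear) passes each parameter's gradient through a linear module before the inner-loop update, adapt_transform=True adapts that transform in the inner loop, and first_order=True skips second-order terms. A minimal sketch of how such a wrapped head is trained; the layer sizes, learning rates, loss, and synthetic task are placeholders, not the repository's config:

import torch
import learn2learn as l2l

# Placeholder head standing in for the script's fc1/fc2/linear_out stack.
head = torch.nn.Sequential(
    torch.nn.Linear(64, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
)

transform = l2l.optim.ModuleTransform(torch.nn.Linear)
head = l2l.algorithms.GBML(head, transform=transform, lr=0.01,
                           adapt_transform=True, first_order=True)
head.cuda()  # assumes a CUDA device, as in the script
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

# One meta-training step on a single synthetic task.
supp_x, supp_y = torch.randn(10, 64).cuda(), torch.randn(10, 1).cuda()
query_x, query_y = torch.randn(10, 64).cuda(), torch.randn(10, 1).cuda()

optimizer.zero_grad()
learner = head.clone()                              # task-specific copy; gradients flow back
learner.adapt(loss_fn(learner(supp_x), supp_y))     # inner-loop adaptation step
query_loss = loss_fn(learner(query_x), query_y)
query_loss.backward()                               # meta-gradient w.r.t. head and transform
optimizer.step()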


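The remaining hunks are the memory optimization from the commit message: reshuffling moves inside the epoch loop so batches differ across epochs, and each batch ends by dropping its tensor references, running the garbage collector, and releasing PyTorch's cached CUDA blocks. torch.cuda.empty_cache() does not free live tensors; it only returns cached, unused blocks to the driver, so dropping references first releases the most memory (in the diff's order, the current batch's blocks are only returned on the next iteration's call). A sketch of the pattern with hypothetical names, assuming a GBML-wrapped head as above:

import gc
import torch

def train_batch(head, optimizer, loss_fn, batch):
    """Hypothetical per-batch step mirroring the cleanup added by the commit."""
    supp_xs, supp_ys, query_xs, query_ys = batch
    optimizer.zero_grad()
    learner = head.clone()
    learner.adapt(loss_fn(learner(supp_xs), supp_ys))
    loss = loss_fn(learner(query_xs), query_ys)
    loss.backward()
    optimizer.step()

    # Cleanup pattern from the diff: unreference tensors, collect cycles,
    # then return the now-unused cached blocks to the CUDA driver.
    del supp_xs, supp_ys, query_xs, query_ys, learner, loss
    gc.collect()
    torch.cuda.empty_cache()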