Browse Source

other hyper parameters

master
mohamad maheri 3 years ago
parent
commit
a9d1772942
2 changed files with 19 additions and 9 deletions
  1. 1
    1
      learnToLearn.py
  2. 18
    8
      learnToLearnTest.py

+ 1
- 1
learnToLearn.py View File

print("META LEARNING PHASE") print("META LEARNING PHASE")
# head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True) # head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True)
transform = l2l.optim.ModuleTransform(torch.nn.Linear) transform = l2l.optim.ModuleTransform(torch.nn.Linear)
head = l2l.algorithms.GBML(head , transform=transform , lr=config['local_lr'] , adapt_transform=True,first_order=True)
head = l2l.algorithms.GBML(head , transform=transform , lr=config['local_lr'] , adapt_transform=True,first_order=False)


if use_cuda: if use_cuda:
head.cuda() head.cuda()

+ 18
- 8
learnToLearnTest.py View File

ndcgs3 = [] ndcgs3 = []


for iterator in range(test_set_size): for iterator in range(test_set_size):
try:
supp_xs = a[iterator].cuda()
supp_ys = b[iterator].cuda()
query_xs = c[iterator].cuda()
query_ys = d[iterator].cuda()
except IndexError:
print("index error in test method")
continue
if config['use_cuda']:
try:
supp_xs = a[iterator].cuda()
supp_ys = b[iterator].cuda()
query_xs = c[iterator].cuda()
query_ys = d[iterator].cuda()
except IndexError:
print("index error in test method")
continue
else:
try:
supp_xs = a[iterator]
supp_ys = b[iterator]
query_xs = c[iterator]
query_ys = d[iterator]
except IndexError:
print("index error in test method")
continue


num_local_update = config['inner'] num_local_update = config['inner']
learner = head.clone() learner = head.clone()

Loading…
Cancel
Save