Browse Source

other hyper parameters

master
mohamad maheri 2 years ago
parent
commit
a9d1772942
2 changed files with 19 additions and 9 deletions
  1. 1
    1
      learnToLearn.py
  2. 18
    8
      learnToLearnTest.py

+ 1
- 1
learnToLearn.py View File

@@ -52,7 +52,7 @@ else:
print("META LEARNING PHASE")
# head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'],first_order=True)
transform = l2l.optim.ModuleTransform(torch.nn.Linear)
head = l2l.algorithms.GBML(head , transform=transform , lr=config['local_lr'] , adapt_transform=True,first_order=True)
head = l2l.algorithms.GBML(head , transform=transform , lr=config['local_lr'] , adapt_transform=True,first_order=False)

if use_cuda:
head.cuda()

+ 18
- 8
learnToLearnTest.py View File

@@ -20,14 +20,24 @@ def test(embedding,head, total_dataset, batch_size, num_epoch):
ndcgs3 = []

for iterator in range(test_set_size):
try:
supp_xs = a[iterator].cuda()
supp_ys = b[iterator].cuda()
query_xs = c[iterator].cuda()
query_ys = d[iterator].cuda()
except IndexError:
print("index error in test method")
continue
if config['use_cuda']:
try:
supp_xs = a[iterator].cuda()
supp_ys = b[iterator].cuda()
query_xs = c[iterator].cuda()
query_ys = d[iterator].cuda()
except IndexError:
print("index error in test method")
continue
else:
try:
supp_xs = a[iterator]
supp_ys = b[iterator]
query_xs = c[iterator]
query_ys = d[iterator]
except IndexError:
print("index error in test method")
continue

num_local_update = config['inner']
learner = head.clone()

Loading…
Cancel
Save