|
|
@@ -20,14 +20,24 @@ def test(embedding,head, total_dataset, batch_size, num_epoch): |
|
|
|
ndcgs3 = [] |
|
|
|
|
|
|
|
for iterator in range(test_set_size): |
|
|
|
try: |
|
|
|
supp_xs = a[iterator].cuda() |
|
|
|
supp_ys = b[iterator].cuda() |
|
|
|
query_xs = c[iterator].cuda() |
|
|
|
query_ys = d[iterator].cuda() |
|
|
|
except IndexError: |
|
|
|
print("index error in test method") |
|
|
|
continue |
|
|
|
if config['use_cuda']: |
|
|
|
try: |
|
|
|
supp_xs = a[iterator].cuda() |
|
|
|
supp_ys = b[iterator].cuda() |
|
|
|
query_xs = c[iterator].cuda() |
|
|
|
query_ys = d[iterator].cuda() |
|
|
|
except IndexError: |
|
|
|
print("index error in test method") |
|
|
|
continue |
|
|
|
else: |
|
|
|
try: |
|
|
|
supp_xs = a[iterator] |
|
|
|
supp_ys = b[iterator] |
|
|
|
query_xs = c[iterator] |
|
|
|
query_ys = d[iterator] |
|
|
|
except IndexError: |
|
|
|
print("index error in test method") |
|
|
|
continue |
|
|
|
|
|
|
|
num_local_update = config['inner'] |
|
|
|
learner = head.clone() |