|  |  | @@ -20,14 +20,24 @@ def test(embedding,head, total_dataset, batch_size, num_epoch): | 
		
	
		
			
			|  |  |  | ndcgs3 = [] | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | for iterator in range(test_set_size): | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | supp_xs = a[iterator].cuda() | 
		
	
		
			
			|  |  |  | supp_ys = b[iterator].cuda() | 
		
	
		
			
			|  |  |  | query_xs = c[iterator].cuda() | 
		
	
		
			
			|  |  |  | query_ys = d[iterator].cuda() | 
		
	
		
			
			|  |  |  | except IndexError: | 
		
	
		
			
			|  |  |  | print("index error in test method") | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | if config['use_cuda']: | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | supp_xs = a[iterator].cuda() | 
		
	
		
			
			|  |  |  | supp_ys = b[iterator].cuda() | 
		
	
		
			
			|  |  |  | query_xs = c[iterator].cuda() | 
		
	
		
			
			|  |  |  | query_ys = d[iterator].cuda() | 
		
	
		
			
			|  |  |  | except IndexError: | 
		
	
		
			
			|  |  |  | print("index error in test method") | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | supp_xs = a[iterator] | 
		
	
		
			
			|  |  |  | supp_ys = b[iterator] | 
		
	
		
			
			|  |  |  | query_xs = c[iterator] | 
		
	
		
			
			|  |  |  | query_ys = d[iterator] | 
		
	
		
			
			|  |  |  | except IndexError: | 
		
	
		
			
			|  |  |  | print("index error in test method") | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | num_local_update = config['inner'] | 
		
	
		
			
			|  |  |  | learner = head.clone() |