@@ -32,9 +32,9 @@ def parse_args():
                         help='outer-loop learning rate (used with Adam optimiser)')
     # parser.add_argument('--lr_meta_decay', type=float, default=0.9, help='decay factor for meta learning rate')
 
-    parser.add_argument('--inner', type=int, default=2,
+    parser.add_argument('--inner', type=int, default=1,
                         help='number of gradient steps in inner loop (during training)')
-    parser.add_argument('--inner_eval', type=int, default=2,
+    parser.add_argument('--inner_eval', type=int, default=1,
                         help='number of gradient updates at test time (for evaluation)')
 
     parser.add_argument('--first_order', action='store_true', default=False,
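
Note: this hunk halves the default inner-loop step counts for both training (`--inner`) and evaluation (`--inner_eval`). For context, a minimal sketch of how such a step count is typically consumed in a learn2learn-style inner loop; `maml`, `loss_fn`, and the tensor names are illustrative, not taken from this repo:

    # Hypothetical sketch: inner_steps plays the role of args.inner.
    learner = maml.clone()                    # task-specific copy of the meta-model
    for _ in range(inner_steps):              # 1 step after this change, 2 before
        support_loss = loss_fn(learner(supp_x), supp_y)
        learner.adapt(support_loss)           # one inner-loop gradient step
    query_loss = loss_fn(learner(query_x), query_y)  # drives the outer update
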
@@ -102,10 +102,10 @@ if __name__ == '__main__':
     fc2_out_dim = config['second_fc_hidden_dim']
     use_cuda = config['use_cuda']
 
-    fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
-    fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
-    linear_out = torch.nn.Linear(fc2_out_dim, 1)
-    head = torch.nn.Sequential(fc1, fc2, linear_out)
+    # fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
+    # fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
+    # linear_out = torch.nn.Linear(fc2_out_dim, 1)
+    # head = torch.nn.Sequential(fc1, fc2, linear_out)
 
     if use_cuda:
         emb = EmbeddingModule(config).cuda()
@@ -123,6 +123,7 @@ if __name__ == '__main__':
     transform = l2l.optim.ModuleTransform(torch.nn.Linear)
 
     trainer = Trainer(config)
+    tr = trainer
 
     # define meta algorithm
    if args.meta_algo == "maml":
@@ -140,7 +141,7 @@ if __name__ == '__main__':
     # Setup optimization
     print("SETUP OPTIMIZATION PHASE")
     all_parameters = list(emb.parameters()) + list(trainer.parameters())
-    optimizer = torch.optim.Adam(all_parameters, lr=args.lr_meta)
+    optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])
     # loss = torch.nn.MSELoss(reduction='mean')
 
     # Load training dataset.
@@ -156,7 +157,47 @@ if __name__ == '__main__':
 
     print("\n\n\n")
 
-    for iteration in range(args.epochs):
+    for iteration in range(config['num_epoch']):
 
+        if iteration == 1:
+            print("changing cluster centroids started ...")
+            indexes = list(np.arange(training_set_size))
+            supp_xs, supp_ys, query_xs, query_ys = [], [], [], []
+            for idx in range(0, 2500):
+                supp_xs.append(pickle.load(open("{}/warm_state/supp_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
+                supp_ys.append(pickle.load(open("{}/warm_state/supp_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
+                query_xs.append(
+                    pickle.load(open("{}/warm_state/query_x_{}.pkl".format(master_path, indexes[idx]), "rb")))
+                query_ys.append(
+                    pickle.load(open("{}/warm_state/query_y_{}.pkl".format(master_path, indexes[idx]), "rb")))
+            batch_sz = len(supp_xs)
+
+            user_embeddings = []
+
+            for task in range(batch_sz):
+                # Compute meta-training loss
+                supp_xs[task] = supp_xs[task].cuda()
+                supp_ys[task] = supp_ys[task].cuda()
+                # query_xs[task] = query_xs[task].cuda()
+                # query_ys[task] = query_ys[task].cuda()
+                temp_sxs = emb(supp_xs[task])
+                # temp_qxs = emb(query_xs[task])
+                y = supp_ys[task].view(-1, 1)
+                input_pairs = torch.cat((temp_sxs, y), dim=1)
+                task_embed = tr.cluster_module.input_to_hidden(input_pairs)
+
+                # todo : may be useless
+                mean_task = tr.cluster_module.aggregate(task_embed)
+                user_embeddings.append(mean_task.detach().cpu().numpy())
+
+                supp_xs[task] = supp_xs[task].cpu()
+                supp_ys[task] = supp_ys[task].cpu()
+
+            from sklearn.cluster import KMeans
+
+            user_embeddings = np.array(user_embeddings)
+            kmeans_model = KMeans(n_clusters=config['cluster_k'], init="k-means++").fit(user_embeddings)
+            tr.cluster_module.array.data = torch.Tensor(kmeans_model.cluster_centers_).cuda()
+
         num_batch = int(training_set_size / batch_size)
         indexes = list(np.arange(training_set_size))
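
Note: the added block runs exactly once, at the start of the second epoch (`iteration == 1`). It loads the first 2500 warm-state tasks, embeds each support set, concatenates the ratings as an extra feature, aggregates each task to a single vector, and overwrites the cluster module's centroid parameter with K-means centers fitted to those vectors. Two small issues worth flagging: the `from sklearn.cluster import KMeans` belongs at the top of the file, and `indexes` is rebuilt but never shuffled, so the slice is deterministic. The same centroid-reset pattern as a self-contained sketch; the function and parameter names are illustrative, not this repo's API:

    import numpy as np
    import torch
    from sklearn.cluster import KMeans

    def reinit_centroids(centroid_param, task_vectors, k):
        """Overwrite a (k, d) centroid tensor in place with K-means centers.

        centroid_param: torch Parameter/Tensor of shape (k, d),
                        e.g. cluster_module.array in this repo.
        task_vectors:   one length-d numpy vector per task.
        """
        embeddings = np.stack(task_vectors)              # shape (n_tasks, d)
        km = KMeans(n_clusters=k, init="k-means++", n_init=10).fit(embeddings)
        with torch.no_grad():
            centroid_param.data.copy_(
                torch.as_tensor(km.cluster_centers_,
                                dtype=centroid_param.dtype,
                                device=centroid_param.device))
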
@@ -199,7 +240,8 @@ if __name__ == '__main__':
                                               temp_qxs,
                                               supp_ys[task],
                                               query_ys[task],
-                                              args.inner)
+                                              config['inner'],
+                                              epoch=iteration)
 
             evaluation_error.backward()
             meta_train_error += evaluation_error.item()
@@ -223,7 +265,7 @@ if __name__ == '__main__':
         gc.collect()
         print("===============================================\n")
 
-        if iteration % 2 == 0:
+        if iteration % 2 == 0 and iteration != 0:
             # testing
             print("start of test phase")
             trainer.eval()
@@ -241,7 +283,8 @@ if __name__ == '__main__':
                 gc.collect()
 
             print("===================== " + test_state + " =====================")
-            mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],num_epoch=config['num_epoch'],test_state=test_state,args=args)
+            mse, ndc1, ndc3 = test(emb, trainer, test_dataset, batch_size=config['batch_size'],
+                                   num_epoch=config['num_epoch'], test_state=test_state, args=args)
             with open("results2.txt", "a") as f:
                 f.write("{}\t{}\t{}\n".format(mse, ndc1, ndc3))
             print("===================================================")