import os
import pickle
import random

import numpy as np
import torch
import learn2learn as l2l

from MeLU import MeLU
from options import config
from model_training import training
from data_generation import generate
from evidence_candidate import selection
from embedding_module import EmbeddingModule
from embeddings import item, user
from fast_adapt import fast_adapt
# test() used below comes from learnToLearnTest; the earlier `from model_test import test`
# was shadowed by this import, so only this one is kept.
from learnToLearnTest import test


# DATA GENERATION
print("DATA GENERATION PHASE")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
master_path = "/media/external_10TB/10TB/maheri/melu_data5"
if not os.path.exists("{}/".format(master_path)):
    os.mkdir("{}/".format(master_path))
# Prepare the dataset; it needs about 22 GB of hard disk space.
generate(master_path)
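
# generate() is expected to populate, for each state directory
# (warm_state, user_cold_state, item_cold_state, user_and_item_cold_state),
# per-task pickles supp_x_{idx}.pkl / supp_y_{idx}.pkl / query_x_{idx}.pkl / query_y_{idx}.pkl;
# these are loaded below for meta-training and for the test phase.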

# TRAINING
print("TRAINING PHASE")
embedding_dim = config['embedding_dim']
fc1_in_dim = config['embedding_dim'] * 8
fc2_in_dim = config['first_fc_hidden_dim']
fc2_out_dim = config['second_fc_hidden_dim']
use_cuda = config['use_cuda']

emb = EmbeddingModule(config).cuda()

fc1 = torch.nn.Linear(fc1_in_dim, fc2_in_dim)
fc2 = torch.nn.Linear(fc2_in_dim, fc2_out_dim)
linear_out = torch.nn.Linear(fc2_out_dim, 1)
head = torch.nn.Sequential(fc1, fc2, linear_out)
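# The head maps the concatenated field embeddings (hence fc1_in_dim = embedding_dim * 8,
# presumably eight embedded user/item fields) to a single predicted rating. Note that no
# activation functions are placed between the linear layers here; without them the stack is
# mathematically equivalent to a single linear layer, whereas MeLU-style estimators usually
# interleave ReLUs between the fully connected layers.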


# META LEARNING
print("META LEARNING PHASE")
head = l2l.algorithms.MetaSGD(head, lr=config['local_lr'], first_order=True)
head.cuda()
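# MetaSGD wraps the head so that, inside each task, learner = head.clone() can be adapted with
# learner.adapt(loss); it also learns a per-parameter inner learning rate, initialised from
# config['local_lr']. first_order=True uses the first-order approximation, i.e. second-order
# gradients through the inner updates are not backpropagated.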

# Setup optimization
print("SETUP OPTIMIZATION PHASE")
all_parameters = list(emb.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(all_parameters, lr=config['lr'])
# loss = torch.nn.MSELoss(reduction='mean')
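
# The commented-out MSELoss above suggests the task loss is computed inside fast_adapt.
# fast_adapt is assumed to follow the learn2learn MAML/MetaSGD examples: adapt a cloned learner
# on the support set, then return its loss on the query set. A minimal sketch under that
# assumption (the actual implementation lives in fast_adapt.py):
#
#     def fast_adapt(learner, support_x, query_x, support_y, query_y, adaptation_steps):
#         for _ in range(adaptation_steps):
#             support_loss = torch.nn.functional.mse_loss(learner(support_x), support_y)
#             learner.adapt(support_loss)
#         return torch.nn.functional.mse_loss(learner(query_x), query_y)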

# Load training dataset.
print("LOAD DATASET PHASE")
training_set_size = int(len(os.listdir("{}/warm_state".format(master_path))) / 4)
supp_xs_s = []
supp_ys_s = []
query_xs_s = []
query_ys_s = []
for idx in range(training_set_size):
    with open("{}/warm_state/supp_x_{}.pkl".format(master_path, idx), "rb") as f:
        supp_xs_s.append(pickle.load(f))
    with open("{}/warm_state/supp_y_{}.pkl".format(master_path, idx), "rb") as f:
        supp_ys_s.append(pickle.load(f))
    with open("{}/warm_state/query_x_{}.pkl".format(master_path, idx), "rb") as f:
        query_xs_s.append(pickle.load(f))
    with open("{}/warm_state/query_y_{}.pkl".format(master_path, idx), "rb") as f:
        query_ys_s.append(pickle.load(f))
total_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
del supp_xs_s, supp_ys_s, query_xs_s, query_ys_s
training_set_size = len(total_dataset)
batch_size = config['batch_size']
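
# Each element of total_dataset is one user's task: the support set is used for the inner
# (per-task) adaptation and the query set for the outer meta-update.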

print("\n\n\n")
for iteration in range(config['num_epoch']):
    random.shuffle(total_dataset)
    num_batch = int(training_set_size / batch_size)
    a, b, c, d = zip(*total_dataset)

    for i in range(num_batch):
        optimizer.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0

        print("EPOCH: ", iteration, " BATCH: ", i)
        supp_xs = list(a[batch_size * i:batch_size * (i + 1)])
        supp_ys = list(b[batch_size * i:batch_size * (i + 1)])
        query_xs = list(c[batch_size * i:batch_size * (i + 1)])
        query_ys = list(d[batch_size * i:batch_size * (i + 1)])
        batch_sz = len(supp_xs)

        # Move this batch of tasks to the GPU; iterate over the actual batch size (batch_sz)
        # so a short final batch cannot index out of range.
        for j in range(batch_sz):
            supp_xs[j] = supp_xs[j].cuda()
            supp_ys[j] = supp_ys[j].cuda()
            query_xs[j] = query_xs[j].cuda()
            query_ys[j] = query_ys[j].cuda()

        for task in range(batch_sz):
            # print("EPOCH: ", iteration, " BATCH: ", i, " TASK: ", task)

            # Compute meta-training loss
            learner = head.clone()
            temp_sxs = emb(supp_xs[task])
            temp_qxs = emb(query_xs[task])

            evaluation_error = fast_adapt(learner,
                                          temp_sxs,
                                          temp_qxs,
                                          supp_ys[task],
                                          query_ys[task],
                                          config['inner'])

            evaluation_error.backward()
            meta_train_error += evaluation_error.item()

        # Print some metrics
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / batch_sz)
        # print('Meta Train Accuracy', meta_train_accuracy / batch_sz)
        # print('Meta Valid Error', meta_valid_error / batch_sz)
        # print('Meta Valid Accuracy', meta_valid_accuracy / batch_sz)
        # print('Meta Test Error', meta_test_error / batch_sz)
        # print('Meta Test Accuracy', meta_test_accuracy / batch_sz)

        # Average the accumulated gradients and optimize
        for p in all_parameters:
            if p.grad is not None:
                p.grad.data.mul_(1.0 / batch_sz)
        optimizer.step()

        print("===============================================\n")


# save model
final_model = torch.nn.Sequential(emb, head)
torch.save(final_model.state_dict(), master_path + "/models_sgd.pkl")
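# To reuse the weights later one would rebuild the same modules and load the state dict,
# e.g. (a sketch, assuming the same config and module definitions are available):
#
#     emb2 = EmbeddingModule(config)
#     head2 = l2l.algorithms.MetaSGD(
#         torch.nn.Sequential(torch.nn.Linear(fc1_in_dim, fc2_in_dim),
#                             torch.nn.Linear(fc2_in_dim, fc2_out_dim),
#                             torch.nn.Linear(fc2_out_dim, 1)),
#         lr=config['local_lr'], first_order=True)
#     model2 = torch.nn.Sequential(emb2, head2)
#     model2.load_state_dict(torch.load(master_path + "/models_sgd.pkl"))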


# testing
print("start of test phase")
for test_state in ['warm_state', 'user_cold_state', 'item_cold_state', 'user_and_item_cold_state']:
    test_dataset = None
    test_set_size = int(len(os.listdir("{}/{}".format(master_path, test_state))) / 4)
    supp_xs_s = []
    supp_ys_s = []
    query_xs_s = []
    query_ys_s = []
    for idx in range(test_set_size):
        with open("{}/{}/supp_x_{}.pkl".format(master_path, test_state, idx), "rb") as f:
            supp_xs_s.append(pickle.load(f))
        with open("{}/{}/supp_y_{}.pkl".format(master_path, test_state, idx), "rb") as f:
            supp_ys_s.append(pickle.load(f))
        with open("{}/{}/query_x_{}.pkl".format(master_path, test_state, idx), "rb") as f:
            query_xs_s.append(pickle.load(f))
        with open("{}/{}/query_y_{}.pkl".format(master_path, test_state, idx), "rb") as f:
            query_ys_s.append(pickle.load(f))
    test_dataset = list(zip(supp_xs_s, supp_ys_s, query_xs_s, query_ys_s))
    del supp_xs_s, supp_ys_s, query_xs_s, query_ys_s

    print("===================== " + test_state + " =====================")
    test(emb, head, test_dataset, batch_size=config['batch_size'], num_epoch=config['num_epoch'])
    print("===================================================\n\n\n")