''' Evaluate a saved Decoder checkpoint and print its prediction accuracy.

Despite the historical "training" naming used below, nothing in this script
updates weights: it restores a checkpoint, puts the model in eval mode, makes
a single pass over the data and prints the fraction of non-PAD target
positions that were predicted correctly.
'''
import argparse
import math
import time
import pickle

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

import transformer.Constants as Constants
from transformer.Models import Decoder
from transformer.Optim import ScheduledOptim
from DataLoader_Feat import DataLoader

# Index of the GPU that tensors/modules are moved to when CUDA is enabled.
CUDA = 0

# Default checkpoint restored by main(); override with the -checkpoint option.
_DEFAULT_CHECKPOINT = "pa-our-twitterLastfm_test49.chkpt"

# Pickled lookup tables loaded by main().
_IDX2VEC_PATH = '/media/external_3TB/3TB/ramezani/pmoini/Trial/data3/idx2vec.pickle'
_IDX2USER_PATH = '/media/external_3TB/3TB/ramezani/pmoini/Trial/data3/idx2u.pickle'


def get_criterion(user_size):
    ''' Build a summed cross-entropy loss that ignores PAD and EOS targets.

    Args:
        user_size: number of classes, i.e. the length of the class-weight
            vector.

    Returns:
        An nn.CrossEntropyLoss whose class weights are 1 everywhere except
        Constants.PAD and Constants.EOS (weight 0), and which sums rather
        than averages the per-element losses.
    '''
    weight = torch.ones(user_size)
    weight[Constants.PAD] = 0
    weight[Constants.EOS] = 0
    # reduction='sum' is the documented replacement for the deprecated
    # size_average=False.
    return nn.CrossEntropyLoss(weight=weight, reduction='sum')


def get_performance(crit, pred, gold, smoothing=False, num_class=None):
    ''' Compute the loss and the number of correct non-PAD predictions.

    Args:
        crit: loss callable, invoked as crit(pred, flattened_gold).
        pred: score tensor; the predicted class is the arg-max over dim 1
            (presumably shaped (positions, classes) -- confirm against the
            model's output).
        gold: tensor of target indices; flattened to 1-D before use.
        smoothing: unused; kept so existing callers keep working.
        num_class: unused; kept so existing callers keep working.

    Returns:
        (loss, n_correct): the criterion's output and a 0-dim tensor counting
        positions whose target is not Constants.PAD and whose prediction
        equals the target.
    '''
    flat_gold = gold.contiguous().view(-1)
    loss = crit(pred, flat_gold)

    predicted = pred.max(1)[1]
    hits = predicted.eq(flat_gold)
    # PAD positions are excluded from the count.
    n_correct = hits.masked_select(flat_gold.ne(Constants.PAD)).sum()
    return loss, n_correct


def _accuracy(n_correct, n_words):
    ''' Return n_correct / n_words, or 0.0 when no words were counted. '''
    return n_correct / n_words if n_words else 0.0


def train_epoch(model, training_data, crit, opt):
    ''' Run one pass over `training_data` and report the running accuracy.

    No backward pass or optimizer step is performed here, so the pass runs
    under torch.no_grad().

    Args:
        model: callable invoked as model(tgt, h0); the first element of its
            return value is used as the prediction scores.
        training_data: iterable yielding batch tensors; batch[:, 1:] is used
            as the target.
        crit: loss criterion forwarded to get_performance.
        opt: options object providing `d_word_vec` and `cuda`.

    Returns:
        The accuracy over all non-PAD target positions as a float, or 0.0 if
        no such position was seen (earlier versions returned None).
    '''
    # Plain Python ints: avoids integer-tensor division and keeps the final
    # report working even when the data iterable is empty.
    n_total_words = 0
    n_total_correct = 0

    with torch.no_grad():
        for batch in tqdm(
                training_data, mininterval=2,
                desc=' - (Training) ', leave=False):
            tgt = batch
            gold = tgt[:, 1:]

            # NOTE(review): the factor 8 presumably matches the hidden size
            # the model expects for h0 -- confirm against Decoder.
            h0 = torch.zeros(1, batch.size(0), 8 * opt.d_word_vec)
            if opt.cuda:
                h0 = h0.cuda(CUDA)

            pred, *_ = model(tgt, h0)

            _loss, n_correct = get_performance(crit, pred, gold)
            n_total_words += gold.ne(Constants.PAD).sum().item()
            n_total_correct += n_correct.item()

            # Running accuracy after each batch.
            print("--------------")
            print(_accuracy(n_total_correct, n_total_words))
            print("-------------")

    accuracy = _accuracy(n_total_correct, n_total_words)
    print(accuracy)
    return accuracy


def _load_pickle(path):
    ''' Unpickle and return the object stored at `path`, closing the file.

    pickle can execute arbitrary code while loading; only use on trusted
    files.
    '''
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def main():
    ''' Parse options, restore the checkpoint and print its accuracy. '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-epoch', type=int, default=50)
    parser.add_argument('-batch_size', type=int, default=4)
    parser.add_argument('-d_model', type=int, default=1)
    parser.add_argument('-d_inner_hid', type=int, default=64)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)
    parser.add_argument('-window_size', type=int, default=3)
    parser.add_argument('-finit', type=int, default=0)
    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_warmup_steps', type=int, default=1000)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')
    parser.add_argument('-save_model', default='Lastfm_test')
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-checkpoint', default=_DEFAULT_CHECKPOINT,
                        help='path of the checkpoint to evaluate')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    # ========= Preparing DataLoader =========#
    train_data = DataLoader(use_valid=False, load_dict=True,
                            batch_size=opt.batch_size, cuda=opt.cuda)
    opt.user_size = train_data.user_size

    # ========= Preparing Model =========#
    model = Decoder(
        opt.user_size,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_word_vec=opt.d_word_vec,
        d_inner_hid=opt.d_inner_hid,
        n_head=opt.n_head,
        kernel_size=opt.window_size,
        dropout=opt.dropout)

    # Load onto the CPU first so a checkpoint saved on a GPU can be restored
    # on any machine; load_state_dict copies the values into the model's own
    # parameters, which are moved to the GPU below when requested.
    checkpoint = torch.load(opt.checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint["model"])
    model.eval()

    crit = get_criterion(train_data.user_size)

    # Honour -no_cuda: only touch the GPU when CUDA was requested.
    if opt.cuda:
        model = model.cuda(CUDA)
        crit = crit.cuda(CUDA)

    # NOTE: neither table is used further in this function; they are still
    # loaded so that a missing data file is reported exactly as before.
    idx2vec = _load_pickle(_IDX2VEC_PATH)
    id_to_user = _load_pickle(_IDX2USER_PATH)

    train_epoch(model, train_data, crit, opt)


if __name__ == '__main__':
    main()