
test_code.py 4.6KB

'''
Load a trained model checkpoint and run it over the training data to report prediction accuracy.
'''
import argparse
import math
import time
import pickle

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

import transformer.Constants as Constants
from transformer.Models import Decoder
from transformer.Optim import ScheduledOptim
from DataLoader_Feat import DataLoader

CUDA = 0
def get_criterion(user_size):
    ''' Cross-entropy loss that gives the PAD and EOS tokens zero weight. '''
    weight = torch.ones(user_size)
    weight[Constants.PAD] = 0
    weight[Constants.EOS] = 0
    # `size_average=False` is deprecated; `reduction='sum'` keeps the same behaviour.
    return nn.CrossEntropyLoss(weight, reduction='sum')
def get_performance(crit, pred, gold, smoothing=False, num_class=None):
    ''' Compute the summed loss and the number of correct non-PAD predictions. '''
    loss = crit(pred, gold.contiguous().view(-1))

    pred = pred.max(1)[1]                      # predicted class index per position
    gold = gold.contiguous().view(-1)
    n_correct = pred.data.eq(gold.data)
    n_correct = n_correct.masked_select(gold.ne(Constants.PAD).data).sum()

    return loss, n_correct
def train_epoch(model, training_data, crit, opt):
    ''' Epoch operation: run the model over the data and accumulate accuracy. '''
    # model.train()
    total_loss = 0
    n_total_words = 0
    n_total_correct = 0

    for batch in tqdm(
            training_data, mininterval=2,
            desc=' - (Training) ', leave=False):
        # prepare data
        tgt = batch
        gold = tgt[:, 1:]
        h0 = torch.zeros(1, batch.size(0), 8 * opt.d_word_vec)
        if opt.cuda:
            h0 = h0.cuda(CUDA)

        # forward
        pred, *_ = model(tgt, h0)
        loss, n_correct = get_performance(crit, pred, gold)

        # note keeping
        n_words = gold.data.ne(Constants.PAD).sum()
        n_total_words += n_words
        n_total_correct += n_correct
        # total_loss += loss.item()

        print("--------------")
        print(n_total_correct / n_total_words)
        print("-------------")

    # total_loss = total_loss.data.cpu().numpy()
    n_total_correct = n_total_correct.data.cpu().numpy()
    n_total_words = n_total_words.data.cpu().numpy()
    print(n_total_correct / n_total_words)
    # return total_loss / n_total_words, n_total_correct / n_total_words
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-epoch', type=int, default=50)
    parser.add_argument('-batch_size', type=int, default=4)
    parser.add_argument('-d_model', type=int, default=1)
    parser.add_argument('-d_inner_hid', type=int, default=64)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)
    parser.add_argument('-window_size', type=int, default=3)
    parser.add_argument('-finit', type=int, default=0)
    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_warmup_steps', type=int, default=1000)
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')
    parser.add_argument('-save_model', default='Lastfm_test')
    parser.add_argument('-no_cuda', action='store_true')
    # torch.cuda.set_device(1)

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    # ========= Preparing DataLoader =========#
    train_data = DataLoader(use_valid=False, load_dict=True, batch_size=opt.batch_size, cuda=opt.cuda)
    opt.user_size = train_data.user_size

    # ========= Preparing Model =========#
    model = Decoder(
        opt.user_size,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_word_vec=opt.d_word_vec,
        d_inner_hid=opt.d_inner_hid,
        n_head=opt.n_head,
        kernel_size=opt.window_size,
        dropout=opt.dropout)

    # Restore the trained weights and switch to evaluation mode.
    model.load_state_dict(torch.load("pa-our-twitterLastfm_test49.chkpt")["model"])
    model.eval()

    crit = get_criterion(train_data.user_size)

    # Index-to-embedding and index-to-user lookups (loaded here but not used below).
    with open('/media/external_3TB/3TB/ramezani/pmoini/Trial/data3/idx2vec.pickle', 'rb') as f:
        idx2vec = pickle.load(f)
    # clusters_vecs = pickle.load(open('../DeepDiffuseCode/user_embedding/dim_160/clusters_vectors.p', 'rb'))
    # user_to_cluster = pickle.load(open('../DeepDiffuseCode/user_embedding/dim_160/users_to_cluster.p', 'rb'))
    with open('/media/external_3TB/3TB/ramezani/pmoini/Trial/data3/idx2u.pickle', 'rb') as f:
        id_to_user = pickle.load(f)

    # Move model and criterion to the GPU only when CUDA is requested.
    if opt.cuda:
        model = model.cuda(CUDA)
        crit = crit.cuda(CUDA)

    train_epoch(model, train_data, crit, opt)


if __name__ == '__main__':
    main()
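
A minimal, self-contained sketch of how the PAD-masked criterion from get_criterion and the correctness count from get_performance behave on a tiny batch. The PAD index, vocabulary size, logits, and targets below are toy values assumed purely for illustration (Constants.PAD is taken to be 0 here).

import torch
import torch.nn as nn

PAD = 0                        # assumed PAD index (stands in for Constants.PAD)
user_size = 5                  # hypothetical vocabulary of 5 users

weight = torch.ones(user_size)
weight[PAD] = 0                # PAD targets contribute nothing to the loss
crit = nn.CrossEntropyLoss(weight, reduction='sum')

logits = torch.randn(4, user_size)     # 4 positions, one score per user
gold = torch.tensor([2, 3, PAD, 1])    # toy targets; one PAD position

loss = crit(logits, gold)                                    # PAD position excluded from the sum
pred = logits.max(1)[1]                                      # predicted user index per position
n_correct = pred.eq(gold).masked_select(gold.ne(PAD)).sum()  # correct predictions on non-PAD positions
n_words = gold.ne(PAD).sum()                                 # number of non-PAD positions

print(loss.item(), n_correct.item(), n_words.item())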