
run-script.py

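# Evaluation script: restores a trained Decoder checkpoint and reports
# token-level prediction accuracy (ignoring PAD positions) over the
# DataLoader_Feat dataset.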
import argparse
import math
import time
import numpy
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import transformer.Constants as Constants
from transformer.Models import Decoder
from transformer.Optim import ScheduledOptim
from DataLoader_Feat import DataLoader
import pickle

CUDA = 1  # index of the GPU device used throughout

parser = argparse.ArgumentParser()
def get_performance(crit, pred, gold):
    ''' Return the loss and the number of correctly predicted non-PAD tokens. '''
    loss = crit(pred, gold.contiguous().view(-1))
    pred = pred.max(1)[1]
    gold = gold.contiguous().view(-1)
    n_correct = pred.data.eq(gold.data)
    n_correct = n_correct.masked_select(gold.ne(Constants.PAD).data).sum()
    return loss, n_correct
def closest_cluster(input_vector):
    ''' Return the index of the cluster vector with the largest dot product
        with input_vector. Relies on the global clusters_vecs, which is only
        defined if the commented-out pickle load further down is re-enabled. '''
    best_result = float('-inf')
    best_index = -1
    for cluster_ind, cluster_vec in clusters_vecs.items():
        result = numpy.dot(cluster_vec, input_vector)
        # print("result:{0}".format(result))
        if best_result < result:
            best_result = result
            best_index = cluster_ind
    return best_index
def get_criterion(user_size):
    ''' Cross-entropy loss with zero weight on the PAD and EOS tokens. '''
    weight = torch.ones(user_size)
    weight[Constants.PAD] = 0
    weight[Constants.EOS] = 0
    # reduction='sum' replaces the deprecated size_average=False
    return nn.CrossEntropyLoss(weight, reduction='sum')
parser.add_argument('-epoch', type=int, default=75)
parser.add_argument('-batch_size', type=int, default=4)  # CHANGE 8->32
parser.add_argument('-d_model', type=int, default=1)
parser.add_argument('-d_inner_hid', type=int, default=64)
parser.add_argument('-d_k', type=int, default=64)
parser.add_argument('-d_v', type=int, default=64)
parser.add_argument('-window_size', type=int, default=3)
parser.add_argument('-finit', type=int, default=0)
parser.add_argument('-n_head', type=int, default=8)
parser.add_argument('-n_warmup_steps', type=int, default=1000)
parser.add_argument('-dropout', type=float, default=0.1)
parser.add_argument('-embs_share_weight', action='store_true')
parser.add_argument('-proj_share_weight', action='store_true')
parser.add_argument('-save_model', default='Lastfm_test')
parser.add_argument('-no_cuda', action='store_true')
# torch.cuda.set_device(1)
opt = parser.parse_args()
opt.cuda = not opt.no_cuda
opt.d_word_vec = opt.d_model

# ========= Preparing DataLoader =========#
train_data = DataLoader(use_valid=False, batch_size=opt.batch_size, cuda=opt.cuda)
# ========= Preparing Model =========#
opt.user_size = train_data.user_size
model = Decoder(
    opt.user_size,
    d_k=opt.d_k,
    d_v=opt.d_v,
    d_model=opt.d_model,
    d_word_vec=opt.d_word_vec,
    d_inner_hid=opt.d_inner_hid,
    n_head=opt.n_head,
    kernel_size=opt.window_size,
    dropout=opt.dropout)
optimizer = ScheduledOptim(
    optim.Adam(
        model.parameters(),
        betas=(0.9, 0.98), eps=1e-09),
    opt.d_model, opt.n_warmup_steps)
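# Note: the optimizer is constructed for parity with the training script but is
# never stepped below; this script only runs evaluation.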
model.load_state_dict(torch.load('pa-our-twitterLastfm_test49.chkpt')['model'])
model.cuda(CUDA)
crit = get_criterion(train_data.user_size)
crit.cuda(CUDA)
model.eval()

total_loss = 0
n_total_words = 0
n_total_correct = 0

with open('/media/external_3TB/3TB/ramezani/pmoini/Trial/data3/idx2vec.pickle', 'rb') as f:
    idx2vec = pickle.load(f)
# clusters_vecs = pickle.load(open('../DeepDiffuseCode/user_embedding/dim_160/clusters_vectors.p', 'rb'))
# user_to_cluster = pickle.load(open('../DeepDiffuseCode/user_embedding/dim_160/users_to_cluster.p', 'rb'))
# id_to_user = pickle.load(open('./data/idx2u.pickle', 'rb'))
for batch in tqdm(train_data, mininterval=2, desc=' - (Test) - ', leave=False):
    tgt = batch
    gold = tgt[:, 1:]  # gold targets are the input sequence shifted by one step
    h0 = torch.zeros(1, batch.size(0), 8 * opt.d_word_vec)
    if opt.cuda:
        h0 = h0.cuda(CUDA)
    pred, *_ = model(tgt, h0)
    loss, n_correct = get_performance(crit, pred, gold)
    # note keeping
    n_words = gold.data.ne(Constants.PAD).sum()
    n_total_words += n_words.item()
    n_total_correct += n_correct.item()
    # total_loss += loss.item()

print(n_total_correct / n_total_words)  # overall accuracy over non-PAD positions
# total_loss = total_loss.data.cpu().numpy()
# n_total_correct = n_total_correct.data.cpu().numpy()
# n_total_words = n_total_words.data.cpu().numpy()
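# Example invocation (a sketch; assumes the checkpoint 'pa-our-twitterLastfm_test49.chkpt'
# and the hard-coded idx2vec.pickle path above exist on this machine):
#   python run-script.py -batch_size 4 -d_model 1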