You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

evaluate_new.py 7.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. ''' Translate input text with trained model. '''
  2. import torch
  3. import random
  4. import argparse
  5. from tqdm import tqdm
  6. import numpy as np
  7. from transformer.Translator_new import Translator
  8. import transformer.Constants as Constants
  9. from DataLoader_Feat import DataLoader
  10. import pickle
# Hard-coded, machine-specific path to a pickled index->vector map; loaded in
# main() (presumably user/token index -> embedding vector — confirm against
# the producer of idx2vec.pickle). TODO(review): make this a CLI argument.
idx2vec_addr = '/media/external_3TB/3TB/ramezani/pmoini/Trial/data_d8/idx2vec.pickle'
  12. def getF1(ground_truth, pred_cnt):
  13. right = np.dot(ground_truth, pred_cnt)
  14. pred = np.sum(pred_cnt)
  15. total = np.sum(ground_truth)
  16. print(right, pred, total)
  17. if pred == 0:
  18. return 0, 0, 0
  19. precision = right / pred
  20. recall = right / total
  21. if precision == 0 or recall == 0:
  22. return 0, 0, 0
  23. return (2 * precision * recall) / (precision + recall), precision, recall
  24. def getMSE(ground_truth, pred_cnt):
  25. return mean_squared_error(ground_truth, pred_cnt)
  26. def getMAP(ground_truth, pred_cnt):
  27. size_cascade = list(ground_truth).count(1)
  28. avgp = 0.
  29. ite_list = [[idx, item, ground_truth[idx]] for idx, item in enumerate(pred_cnt)]
  30. ite_list = sorted(ite_list, key=lambda x: x[1], reverse=True)
  31. n_positive = 0
  32. idx = 0
  33. while n_positive < size_cascade:
  34. if ite_list[idx][2] == 1:
  35. n_positive += 1
  36. pre = n_positive / (idx + 1.)
  37. avgp = avgp + pre
  38. idx += 1
  39. # if idx >= 1:
  40. # break
  41. avgp = avgp / size_cascade
  42. return avgp
  43. def main():
  44. '''Main Function'''
  45. parser = argparse.ArgumentParser(description='translate.py')
  46. parser.add_argument('-model', default="/media/external_3TB/3TB/ramezani/pmoini/tmp_debug_NDM/NeuralDiffusionModel-master/pa-our-twitterLastfm_test49.chkpt",
  47. help='Path to model .pt file')
  48. '''
  49. parser.add_argument('-src', required=True,
  50. help='Source sequence to decode (one line per sequence)')
  51. parser.add_argument('-vocab', required=True,
  52. help='Source sequence to decode (one line per sequence)')
  53. '''
  54. parser.add_argument('-output', default='pred.txt',
  55. help="""Path to output the predictions (each line will
  56. be the decoded sequence""")
  57. parser.add_argument('-beam_size', type=int, default=100,
  58. help='Beam size')
  59. parser.add_argument('-batch_size', type=int, default=1,
  60. help='Batch size')
  61. parser.add_argument('-n_best', type=int, default=1,
  62. help="""If verbose is set, will output the n_best
  63. decoded sentences""")
  64. parser.add_argument('-no_cuda', action='store_true')
  65. # parser.add_argument('-seq_len', type=int ,default=20)
  66. parser.add_argument('-pred_len', type=int, default=15)
  67. opt = parser.parse_args()
  68. opt.cuda = not opt.no_cuda
  69. # Prepare DataLoader
  70. test_data = DataLoader(use_valid=False, batch_size=opt.batch_size, cuda=opt.cuda)
  71. translator = Translator(opt)
  72. translator.model.eval()
  73. numuser = test_data.user_size
  74. num_right = 0
  75. num_total = 0
  76. avgF1 = 0
  77. avgPre = 0
  78. avgRec = 0
  79. avgF1_long = 0
  80. avgPre_long = 0
  81. avgRec_long = 0
  82. avgF1_short = 0
  83. avgPre_short = 0
  84. avgRec_short = 0
  85. numseq = 0 # number of test seqs
  86. # for micro pre rec f1
  87. right = 0.
  88. pred = 0.
  89. total = 0.
  90. right_long = 0.
  91. pred_long = 0.
  92. total_long = 0.
  93. right_short = 0.
  94. pred_short = 0.
  95. total_short = 0.
  96. with open(opt.output, 'w') as f:
  97. with open(idx2vec_addr, 'rb') as handle:
  98. idx2vec = pickle.load(handle)
  99. for (src, src_mask, src_lengths, trg, trg_y, trg_mask, trg_lengths) in tqdm(test_data, mininterval=2,
  100. desc=' - (Test)', leave=False):
  101. # src = torch.FloatTensor([[ idx2vec[int(idx.data.cpu().numpy())] for idx in src[i]] for i in range(src.size(0))]).cuda(0)
  102. all_samples = translator.translate_batch(src, src_mask, src_lengths, trg, trg_mask, trg_lengths).data
  103. for bid in range(src.size(0)):
  104. numseq += 1.0
  105. ground_truth = np.zeros([numuser])
  106. num_ground_truth = 0
  107. for user in trg_y.data[bid][1:-1]:
  108. if user == Constants.EOS or user == Constants.PAD:
  109. break
  110. ground_truth[user] = 1.0
  111. num_ground_truth += 1
  112. pred_cnt = np.zeros([numuser])
  113. for beid in range(opt.beam_size):
  114. # print(all_samples[0,0,0:1])
  115. for j in range(num_ground_truth):
  116. # print(bid,beid,j)
  117. pred_uid = all_samples[bid, beid, j]
  118. if pred_uid == Constants.EOS:
  119. break
  120. else:
  121. pred_cnt[pred_uid] += 1.0 / opt.beam_size
  122. F1, pre, rec = getF1(ground_truth, pred_cnt)
  123. avgF1 += F1
  124. avgPre += pre
  125. avgRec += rec
  126. right += np.dot(ground_truth, pred_cnt)
  127. pred += np.sum(pred_cnt)
  128. total += np.sum(ground_truth)
  129. # for short user
  130. ground_truth = np.zeros([numuser])
  131. num_ground_truth = 0
  132. for user in trg_y.data[bid][1:-1]:
  133. if user == Constants.EOS or user == Constants.PAD:
  134. break
  135. ground_truth[user] = 1.0
  136. num_ground_truth += 1
  137. if num_ground_truth >= 5:
  138. break
  139. pred_cnt = np.zeros([numuser])
  140. for beid in range(opt.beam_size):
  141. # total += len(ground_truth)
  142. for j in range(num_ground_truth):
  143. pred_uid = all_samples[bid, beid, j]
  144. if pred_uid == Constants.EOS:
  145. break
  146. # continue
  147. else:
  148. pred_cnt[pred_uid] += 1.0 / opt.beam_size
  149. F1, pre, rec = getF1(ground_truth, pred_cnt)
  150. avgF1_short += F1
  151. avgPre_short += pre
  152. avgRec_short += rec
  153. right_short += np.dot(ground_truth, pred_cnt)
  154. pred_short += np.sum(pred_cnt)
  155. total_short += np.sum(ground_truth)
  156. print(avgF1, avgPre, avgRec)
  157. print('[Info] Finished.')
  158. print('Macro')
  159. print(avgF1 / numseq)
  160. print(avgPre / numseq)
  161. print(avgRec / numseq)
  162. print('Results for the first no more than 5 predictions')
  163. print(avgF1_short / numseq)
  164. print(avgPre_short / numseq)
  165. print(avgRec_short / numseq)
  166. print('Micro')
  167. pmi = right / pred
  168. rmi = right / total
  169. print(2 * pmi * rmi / (pmi + rmi))
  170. print(pmi)
  171. print(rmi)
  172. print('Results for the first no more than 5 predictions')
  173. pmi_long = right_short / pred_short
  174. rmi_long = right_short / total_short
  175. print(2 * pmi_long * rmi_long / (pmi_long + rmi_long))
  176. print(pmi_long)
  177. print(rmi_long)
# Script entry point: run evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()