You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

evaluate.py 7.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
''' Translate input text with trained model. '''
import torch
import random
import argparse
from tqdm import tqdm
import numpy as np
from transformer.Translator_beam import Translator
# from transformer.Translator import Translator
import transformer.Constants as Constants
import pickle
from DataLoader import DataLoader
from sklearn.neighbors import NearestNeighbors

# Pickled mapping loaded in main() and passed to ranking();
# presumably {user index -> embedding vector} — TODO confirm against the pickle contents.
idx2vec_addr = '/home/rafie/NeuralDiffusionModel-master/data/irvine/idx2vec.pickle'
# NOTE(review): unused in this file — looks like a CUDA device selector; confirm before removing.
CUDA = 1
  15. def getF1(ground_truth, pred_cnt):
  16. right = np.dot(ground_truth, pred_cnt)
  17. pred = np.sum(pred_cnt)
  18. total = np.sum(ground_truth)
  19. print(right, pred, total)
  20. if pred == 0:
  21. return 0, 0, 0
  22. precision = right / pred
  23. recall = right / total
  24. if precision == 0 or recall == 0:
  25. return 0, 0, 0
  26. return (2 * precision * recall) / (precision + recall), precision, recall
  27. def getMSE(ground_truth, pred_cnt):
  28. return mean_squared_error(ground_truth, pred_cnt)
  29. def getMAP(ground_truth, pred_cnt):
  30. size_cascade = list(ground_truth).count(1)
  31. avgp = 0.
  32. ite_list = [[idx, item, ground_truth[idx]] for idx, item in enumerate(pred_cnt)]
  33. ite_list = sorted(ite_list, key=lambda x: x[1], reverse=True)
  34. n_positive = 0
  35. idx = 0
  36. while n_positive < size_cascade:
  37. if ite_list[idx][2] == 1:
  38. n_positive += 1
  39. pre = n_positive / (idx + 1.)
  40. avgp = avgp + pre
  41. idx += 1
  42. # if idx >= 1:
  43. # break
  44. avgp = avgp / size_cascade
  45. return avgp
  46. def ranking(seq, candidates, n, idx2vec):
  47. seq_user = np.array([idx2vec[int(idx.data.cpu().numpy())] for idx in seq[0]])
  48. candidates_user = [idx2vec[int(idx.data.cpu().numpy())] for idx in candidates[0]]
  49. nbrs = NearestNeighbors(n_neighbors=100, algorithm='ball_tree').fit(candidates_user)
  50. print('************', candidates_user, seq_user, np.mean(seq_user, axis=0))
  51. distances, indices = nbrs.kneighbors([np.mean(seq_user, axis=0)])
  52. output = [candidates[0][index] for index in indices]
  53. return output
def main():
    '''Evaluate a trained model on the test split.

    Beam-decodes the continuation of each cascade's 10-user prefix with the
    Translator, then prints macro- (per-sequence) and micro- (pooled)
    precision / recall / F1 of the predicted users against the true
    continuation.
    '''
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', default="Irvine.chkpt", required=False,
                        help='Path to model .pt file')
    # Disabled CLI options kept for reference (inert string statement).
    '''
    parser.add_argument('-src', required=True,
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-vocab', required=True,
                        help='Source sequence to decode (one line per sequence)')
    '''
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=100,
                        help='Beam size')
    parser.add_argument('-batch_size', type=int, default=1,
                        help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    # parser.add_argument('-seq_len', type=int ,default=20)
    parser.add_argument('-pred_len', type=int, default=15)
    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    test_data = DataLoader(use_valid=True, batch_size=opt.batch_size, cuda=opt.cuda)

    translator = Translator(opt)
    translator.model.eval()
    numuser = test_data.user_size  # size of the user vocabulary

    # Macro-metric accumulators (averaged over sequences at the end).
    num_right = 0
    num_total = 0
    avgF1 = 0
    avgPre = 0
    avgRec = 0
    # The *_best / *_long / *_short accumulators belong to the commented-out
    # re-ranking experiment below and are never updated by the active code.
    avgF1_best = 0
    avgPre_best = 0
    avgRec_best = 0
    avgF1_long = 0
    avgPre_long = 0
    avgRec_long = 0
    avgF1_short = 0
    avgPre_short = 0
    avgRec_short = 0
    numseq = 0  # number of test seqs
    # for micro pre rec f1
    right = 0.
    pred = 0.
    total = 0.
    rigth_best = 0.  # NOTE(review): typo — the disabled code below reads "right_best"
    pred_best = 0.
    total_best = 0.
    right_long = 0.
    pred_long = 0.
    total_long = 0.
    right_short = 0.
    pred_short = 0.
    total_short = 0.

    # Load the user-index -> embedding mapping (only needed by the disabled
    # ranking() re-ranking path).
    with open(idx2vec_addr, 'rb') as handle:
        idx2vec = pickle.load(handle)

    # NOTE(review): the output file f is opened but never written to.
    with open(opt.output, 'w') as f:
        for batch in tqdm(test_data, mininterval=2, desc=' - (Test)', leave=False):
            n_best = opt.pred_len
            # NOTE(review): opt.pred_len is clobbered per batch with the length
            # remaining after a fixed 10-user prefix; the CLI -pred_len value
            # only seeds n_best above. Confirm this is intentional.
            opt.pred_len = batch.size(1) - 10
            opt.seq_len = 10
            # Beam-decode the continuation of the first seq_len users.
            all_samples = translator.translate_batch(batch[:, 0:opt.seq_len], opt.pred_len).data
            candidates = all_samples.view(1, -1)
            # best_candidates = ranking(batch.data[:, 0:opt.seq_len] , candidates , n_best , idx2vec)
            for bid in range(batch.size(0)):
                numseq += 1.0
                # Binary indicator over all users for the true continuation,
                # stopping at the first EOS/PAD token.
                ground_truth = np.zeros([numuser])
                num_ground_truth = 0
                for user in batch.data[bid][1 + opt.seq_len:-1]:
                    if user == Constants.EOS or user == Constants.PAD:
                        break
                    ground_truth[user] = 1.0
                    num_ground_truth += 1
                # Soft prediction counts: each beam contributes 1/beam_size
                # per user it predicts.
                pred_cnt = np.zeros([numuser])
                for beid in range(opt.beam_size):
                    # NOTE(review): this slice is empty whenever
                    # num_ground_truth + 1 <= 1 + opt.seq_len; confirm the
                    # intended end bound (1 + opt.seq_len + num_ground_truth?).
                    for pred_uid in all_samples[bid, beid, 1 + opt.seq_len:num_ground_truth + 1]:
                        if pred_uid == Constants.EOS:
                            break
                        else:
                            pred_cnt[pred_uid] += 1.0 / opt.beam_size
                # pred_cnt_best = np.zeros([numuser])
                # for user in best_candidates:
                #     if user != Constants.EOS:
                #         pred_cnt_best[user] += 1.0
                F1, pre, rec = getF1(ground_truth, pred_cnt)
                avgF1 += F1
                avgPre += pre
                avgRec += rec
                # Pooled counts for the micro-averaged metrics.
                right += np.dot(ground_truth, pred_cnt)
                pred += np.sum(pred_cnt)
                total += np.sum(ground_truth)
                # F1, pre, rec = getF1(ground_truth, pred_cnt_best)
                # avgF1_best += F1
                # avgPre_best += pre
                # avgRec_best += rec
                # right_best += np.dot(ground_truth, pred_cnt_best)
                # pred_best += np.sum(pred_cnt_best)
                # total_best += np.sum(ground_truth)
            # Running macro totals so far (progress trace).
            print(avgF1, avgPre, avgRec)

    # Final report: macro = mean per sequence, micro = pooled counts.
    print('[Info] Finished.')
    print('Macro')
    print(avgF1 / numseq)
    print(avgPre / numseq)
    print(avgRec / numseq)
    print('Micro')
    pmi = right / pred
    rmi = right / total
    print(2 * pmi * rmi / (pmi + rmi))
    print(pmi)
    print(rmi)
    # Disabled reporting for the re-ranking experiment.
    # print('[Info] Finished.')
    # print('Macro')
    # print(avgF1_best / numseq)
    # print(avgPre_best / numseq)
    # print(avgRec_best / numseq)
    # print('Micro')
    # pmi = right_best / pred_best
    # rmi = right_best / total_best
    # print(2 * pmi * rmi / (pmi + rmi))
    # print(pmi)
    # print(rmi)
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()