Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

evaluation.py 9.8KB

10 months ago
10 months ago
10 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. from itertools import cycle
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. import pandas as pd
  5. import seaborn as sns
  6. from numpy import interp
  7. from sklearn.decomposition import PCA
  8. from sklearn.manifold import TSNE
  9. from sklearn.metrics import accuracy_score, f1_score, precision_recall_curve, average_precision_score
  10. from sklearn.metrics import classification_report, roc_curve, auc
  11. from sklearn.preprocessing import OneHotEncoder
  def metrics(truth, pred, prob, file_path):
      """Compute binary-classification scores, print them, and dump the ROC curve.

      Args:
          truth: iterable of tensors with ground-truth labels; each is moved to
              CPU, converted to numpy and the pieces are concatenated on axis 0.
          pred: iterable of tensors with predicted labels, same layout as truth.
          prob: iterable of tensors with per-class probabilities; after
              concatenation, column 1 is used as the score for ``roc_curve``.
          file_path: CSV path that receives the ROC points as ``fpr``/``tpr``
              columns (the DataFrame index is written too).

      Returns:
          A report string with one ``'<name><value>\\n'`` line (no separator)
          for accuracy, micro/macro/weighted F1 and AUC, in that order.
      """
      truth_np = [t.cpu().numpy() for t in truth]
      pred_np = [p.cpu().numpy() for p in pred]
      prob_np = [p.cpu().numpy() for p in prob]
      y_pred = np.concatenate(pred_np, axis=0)
      y_true = np.concatenate(truth_np, axis=0)
      # Only the column-1 probability is fed to roc_curve.
      positive_score = np.concatenate(prob_np, axis=0)[:, 1]

      f1_results = [
          (label, f1_score(y_true, y_pred, average=avg, zero_division=0))
          for label, avg in (
              ('f_score_micro', 'micro'),
              ('f_score_macro', 'macro'),
              ('f_score_weighted', 'weighted'),
          )
      ]
      accuracy = accuracy_score(y_true, y_pred)

      report_lines = []
      for name, value in [('accuracy', accuracy), *f1_results]:
          print(name, value)
          report_lines.append(name + str(value) + '\n')

      fpr, tpr, _thresholds = roc_curve(y_true, positive_score)
      roc_auc = auc(fpr, tpr)
      print('AUC', roc_auc)
      report_lines.append('AUC' + str(roc_auc) + '\n')

      pd.DataFrame(dict(fpr=fpr, tpr=tpr)).to_csv(file_path)
      return ''.join(report_lines)
  def report_per_class(truth, pred):
      """Print and return scikit-learn's per-class precision/recall/F1 report.

      Args:
          truth: iterable of tensors with ground-truth labels; each is moved to
              CPU, converted to numpy and the pieces are concatenated on axis 0.
          pred: iterable of tensors with predicted labels, same layout as truth.

      Returns:
          For every class, a ``'class_label<label>'`` line followed by that
          class's metrics dict on its own line.
      """
      y_true = np.concatenate([t.cpu().numpy() for t in truth], axis=0)
      y_pred = np.concatenate([p.cpu().numpy() for p in pred], axis=0)
      report = classification_report(y_true, y_pred, zero_division=0, output_dict=True)
      # Summary entries that are not classes. For plain multiclass input the
      # report carries a bare-float 'accuracy' key (instead of 'micro avg');
      # without excluding it, it was printed as if it were a class.
      summary_keys = {'accuracy', 'micro avg', 'macro avg', 'weighted avg', 'samples avg'}
      parts = []
      for class_label, class_metrics in report.items():
          if class_label in summary_keys:
              continue
          print('class_label', class_label)
          parts.append('class_label' + str(class_label) + '\n')
          # End each dict with a newline so the next class header starts on its
          # own line instead of being glued to the closing brace.
          parts.append(str(class_metrics) + '\n')
          print(class_metrics)
      return ''.join(parts)
  def multiclass_acc(truth, pred):
      """Return the fraction of predicted labels that equal the true labels.

      Both arguments are iterables of tensors; each tensor is moved to CPU,
      converted to numpy, and the pieces are concatenated on axis 0 before
      being scored with ``accuracy_score``.
      """
      def _stack(batches):
          return np.concatenate([batch.cpu().numpy() for batch in batches], axis=0)

      return accuracy_score(_stack(truth), _stack(pred))
  def roc_auc_plot(truth, score, num_class=2, fname='roc.png'):
      """Plot per-class, micro- and macro-averaged ROC curves and save them.

      Args:
          truth: iterable of tensors with integer class ids in
              ``0..num_class-1``; moved to CPU and concatenated on axis 0.
          score: iterable of tensors with per-class scores; after
              concatenation, column ``i`` is the score of class ``i``. The
              micro average flattens it alongside the one-hot labels, so it
              must have exactly ``num_class`` columns.
          num_class: number of classes to plot.
          fname: output image path passed to ``savefig``.
      """
      truth = np.concatenate([t.cpu().numpy() for t in truth], axis=0)
      score = np.concatenate([s.cpu().numpy() for s in score], axis=0)
      # One-hot against the fixed ids 0..num_class-1 so column i always means
      # "class i". An encoder fitted on `truth` drops columns for classes that
      # are absent from it, and `label_onehot[:, i]` then raises IndexError or
      # pairs the wrong labels with score[:, i]. An absent class now only
      # yields an undefined curve (sklearn warns) instead of a crash.
      label_onehot = (truth.reshape(-1, 1) == np.arange(num_class)).astype(float)

      fpr_dict, tpr_dict, roc_auc_dict = {}, {}, {}
      for i in range(num_class):
          fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score[:, i])
          roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])

      # micro: pool every (sample, class) decision into one binary problem
      fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score.ravel())
      roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])

      # macro: average per-class TPR on the union of all per-class FPR points
      all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)]))
      mean_tpr = np.zeros_like(all_fpr)
      for i in range(num_class):
          mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])
      mean_tpr /= num_class
      fpr_dict["macro"] = all_fpr
      tpr_dict["macro"] = mean_tpr
      roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])

      fig, ax = plt.subplots()
      lw = 2
      ax.plot(fpr_dict["micro"], tpr_dict["micro"],
              label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc_dict["micro"]),
              color='deeppink', linestyle=':', linewidth=4)
      ax.plot(fpr_dict["macro"], tpr_dict["macro"],
              label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc_dict["macro"]),
              color='navy', linestyle=':', linewidth=4)
      colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
      for i, color in zip(range(num_class), colors):
          ax.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw,
                  label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc_dict[i]))
      ax.plot([0, 1], [0, 1], 'k--', lw=lw)
      ax.set_xlim([0.0, 1.0])
      ax.set_ylim([0.0, 1.05])
      ax.set_xlabel('False Positive Rate')
      ax.set_ylabel('True Positive Rate')
      ax.legend(loc="lower right")
      fig.savefig(fname)
      # Release the figure; otherwise every call (e.g. once per epoch) keeps a
      # figure alive inside pyplot and eventually triggers its
      # "More than 20 figures have been opened" warning.
      plt.close(fig)
  def precision_recall_plot(truth, score, num_class=2, fname='pr.png'):
      """Plot per-class, micro- and macro-averaged precision-recall curves.

      Curves are drawn with recall on the x-axis and precision on the y-axis.
      The legend "area" is the average precision (AP): per class; micro over
      all (sample, class) pairs; and, for the macro curve, the unweighted mean
      of the per-class APs. No y = x reference line is drawn, because in PR
      space the no-skill level is the positive-class prevalence, not the
      diagonal.

      Args:
          truth: iterable of tensors with integer class ids in
              ``0..num_class-1``; moved to CPU and concatenated on axis 0.
          score: iterable of tensors with per-class scores; after
              concatenation, column ``i`` is the score of class ``i`` and the
              array must have exactly ``num_class`` columns (micro averaging
              flattens it alongside the one-hot labels).
          num_class: number of classes to plot.
          fname: output image path passed to ``savefig``.
      """
      truth = np.concatenate([t.cpu().numpy() for t in truth], axis=0)
      score = np.concatenate([s.cpu().numpy() for s in score], axis=0)
      # One-hot against fixed ids so column i always means class i, even when a
      # class is absent from `truth` (an encoder fitted on `truth` would drop it).
      label_onehot = (truth.reshape(-1, 1) == np.arange(num_class)).astype(float)

      precision_dict, recall_dict, average_precision_dict = {}, {}, {}
      for i in range(num_class):
          precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score[:, i])
          average_precision_dict[i] = average_precision_score(label_onehot[:, i], score[:, i])
          print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i])

      # micro
      precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(),
                                                                                score.ravel())
      average_precision_dict["micro"] = average_precision_score(label_onehot, score, average="micro")

      # macro: average precision over a shared recall grid. Precision is not
      # monotonic, so it cannot be np.interp's abscissa. Recall from
      # precision_recall_curve is decreasing, so reverse it (and precision
      # with it) to get the increasing xp that np.interp requires.
      all_recall = np.unique(np.concatenate([recall_dict[i] for i in range(num_class)]))
      mean_precision = np.zeros_like(all_recall)
      for i in range(num_class):
          mean_precision += np.interp(all_recall, recall_dict[i][::-1], precision_dict[i][::-1])
      mean_precision /= num_class
      recall_dict["macro"] = all_recall
      precision_dict["macro"] = mean_precision
      average_precision_dict["macro"] = float(np.mean([average_precision_dict[i] for i in range(num_class)]))

      # A single figure; the previous plt.figure() + plt.subplots() pair leaked
      # an empty extra figure on every call.
      fig, ax = plt.subplots(figsize=(16, 10))
      lw = 2
      ax.plot(recall_dict["micro"], precision_dict["micro"],
              label='micro-average Precision-Recall curve (area = {0:0.2f})'
                    .format(average_precision_dict["micro"]),
              color='deeppink', linestyle=':', linewidth=4)
      ax.plot(recall_dict["macro"], precision_dict["macro"],
              label='macro-average Precision-Recall curve (area = {0:0.2f})'
                    .format(average_precision_dict["macro"]),
              color='navy', linestyle=':', linewidth=4)
      colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
      for i, color in zip(range(num_class), colors):
          ax.plot(recall_dict[i], precision_dict[i], color=color, lw=lw,
                  label='Precision-Recall curve of class {0} (area = {1:0.2f})'
                        .format(i, average_precision_dict[i]))
      ax.set_xlabel('Recall')
      ax.set_ylabel('Precision')
      ax.set_ylim([0.0, 1.05])
      ax.set_xlim([0.0, 1.0])
      ax.legend(loc="lower left")
      fig.savefig(fname)
      plt.close(fig)
  def saving_in_tensorboard(config, x, y, fname='embedding'):
      """Prepare embeddings and label names for a TensorBoard projector export.

      The export call itself is currently disabled, so this only concatenates
      the inputs and maps each label through ``config.classes`` (which raises
      if a label has no entry). Nothing is written and ``None`` is returned.

      Args:
          config: object exposing ``classes``, indexed by each label value.
          x: iterable of tensors with embeddings; concatenated on axis 0.
          y: iterable of tensors with labels; after concatenation, column 0
              (or the only column, for 1-D labels) is mapped to names.
          fname: tag that the disabled export would use.
      """
      embedding_parts = [batch.cpu().numpy() for batch in x]
      label_parts = [batch.cpu().numpy() for batch in y]
      embeddings = np.concatenate(embedding_parts, axis=0)
      labels = np.concatenate(label_parts, axis=0)
      label_names = pd.DataFrame(labels)[0].apply(lambda idx: config.classes[idx]).values
      # Disabled export. NOTE(review): if config.writer is a torch SummaryWriter,
      # `label_img` expects per-point images, so passing labels there looks wrong;
      # the names belong in `metadata`.
      # config.writer.add_embedding(mat=embeddings, label_img=labels, metadata=label_names, tag=fname)
  def plot_tsne(config, x, y, fname='tsne.png'):
      """Project embeddings to 2-D with t-SNE and save a scatter plot coloured by class.

      Args:
          config: object exposing ``classes``, indexed by each label value to
              obtain the legend name.
          x: iterable of tensors with embeddings; concatenated on axis 0.
          y: iterable of tensors with labels; after concatenation, column 0
              (or the only column, for 1-D labels) is mapped through
              ``config.classes``.
          fname: output image path; the file's base name up to the first '.'
              is used in the plot title.
      """
      x = np.concatenate([t.cpu().numpy() for t in x], axis=0)
      y = np.concatenate([t.cpu().numpy() for t in y], axis=0)
      y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values
      tsne = TSNE(n_components=2, verbose=1, init="pca", perplexity=10, learning_rate=1000)
      tsne_proj = tsne.fit_transform(x)
      fig, ax = plt.subplots(figsize=(16, 10))
      # One colour per distinct label; a fixed two-colour palette breaks as soon
      # as there are more classes.
      palette = sns.color_palette("bright", len(pd.unique(y)))
      # seaborn >= 0.12 only accepts x/y as keywords (positional use raises
      # TypeError); pass ax explicitly so the points land on this figure.
      sns.scatterplot(x=tsne_proj[:, 0], y=tsne_proj[:, 1], hue=y, legend='full',
                      palette=palette, ax=ax)
      ax.legend(fontsize='large', markerscale=2)
      ax.set_title('tsne of ' + str(fname.split('/')[-1].split('.')[0]))
      fig.savefig(fname)
      plt.show()
      # Free the figure once shown/saved so repeated calls do not accumulate figures.
      plt.close(fig)
  def save_loss(ids, predictions, targets, l, path):
      """Write per-entry ids, labels and loss components to a CSV file.

      Unlike the other helpers, the per-tensor arrays are NOT concatenated:
      each element of ``ids``/``predictions``/``targets``/``l`` becomes one row,
      and any non-scalar array is written as its string form.

      Args:
          ids: iterable of id tensors.
          predictions: iterable of predicted-label tensors.
          targets: iterable of ground-truth-label tensors.
          l: iterable of indexable loss records; items 0, 1 and 2 are written
              as ``losses``, ``classifier_losses`` and ``similarity_losses``.
          path: destination CSV path (the DataFrame index is written too).
      """
      def _to_numpy(seq):
          return [tensor.cpu().numpy() for tensor in seq]

      columns = {
          'id': _to_numpy(ids),
          'predicted_label': _to_numpy(predictions),
          'real_label': _to_numpy(targets),
          'losses': [record[0].cpu().numpy() for record in l],
          'classifier_losses': [record[1].cpu().numpy() for record in l],
          'similarity_losses': [record[2].cpu().numpy() for record in l],
      }
      pd.DataFrame(columns).to_csv(path)
  def save_embedding(x, fname='embedding.tsv'):
      """Concatenate embedding tensors and write them as a header-less TSV.

      Args:
          x: iterable of tensors; each is moved to CPU, converted to numpy and
              the pieces are concatenated on axis 0 (one TSV row per row).
          fname: destination path; neither index nor header is written.
      """
      stacked = np.concatenate([batch.cpu().numpy() for batch in x], axis=0)
      pd.DataFrame(stacked).to_csv(fname, sep='\t', index=False, header=False)
  def plot_pca(config, x, y, fname='pca.png'):
      """Project embeddings to 2-D with PCA and save a scatter plot coloured by class.

      Args:
          config: object exposing ``classes``, indexed by each label value to
              obtain the legend name.
          x: iterable of tensors with embeddings; concatenated on axis 0.
          y: iterable of tensors with labels; after concatenation, column 0
              (or the only column, for 1-D labels) is mapped through
              ``config.classes``.
          fname: output image path; the file's base name up to the first '.'
              is used in the plot title.
      """
      x = np.concatenate([t.cpu().numpy() for t in x], axis=0)
      y = np.concatenate([t.cpu().numpy() for t in y], axis=0)
      y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values
      pca = PCA(n_components=2)
      pca_proj = pca.fit_transform(x)
      fig, ax = plt.subplots(figsize=(16, 10))
      # One colour per distinct label instead of a fixed two-colour palette.
      palette = sns.color_palette("bright", len(pd.unique(y)))
      # seaborn >= 0.12 only accepts x/y as keywords (positional use raises
      # TypeError); pass ax explicitly so the points land on this figure.
      sns.scatterplot(x=pca_proj[:, 0], y=pca_proj[:, 1], hue=y, legend='full',
                      palette=palette, ax=ax)
      ax.legend(fontsize='large', markerscale=2)
      ax.set_title('pca of ' + str(fname.split('/')[-1].split('.')[0]))
      fig.savefig(fname)
      # Free the figure so repeated calls do not accumulate open figures.
      plt.close(fig)