Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lime_main.py 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import random
  2. import numpy as np
  3. import pandas as pd
  4. import torch
  5. from matplotlib import pyplot as plt
  6. from tqdm import tqdm
  7. from data_loaders import make_dfs, build_loaders
  8. from learner import batch_constructor
  9. from model import FakeNewsModel
  10. from lime.lime_image import LimeImageExplainer
  11. from skimage.segmentation import mark_boundaries
  12. from lime.lime_text import LimeTextExplainer
  13. from utils import make_directory
  14. def lime_(config, test_loader, model):
  15. for i, batch in enumerate(test_loader):
  16. batch = batch_constructor(config, batch)
  17. with torch.no_grad():
  18. output, score = model(batch)
  19. score = score.detach().cpu().numpy()
  20. logit = model.logits.detach().cpu().numpy()
  21. return score, logit
  22. def lime_main(config, trial_number=None):
  23. _, test_df, _ = make_dfs(config, )
  24. test_df = test_df[:1]
  25. if trial_number:
  26. try:
  27. checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
  28. except:
  29. checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
  30. else:
  31. checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))
  32. try:
  33. parameters = checkpoint['parameters']
  34. config.assign_hyperparameters(parameters)
  35. except:
  36. pass
  37. model = FakeNewsModel(config).to(config.device)
  38. try:
  39. model.load_state_dict(checkpoint['model_state_dict'])
  40. except:
  41. model.load_state_dict(checkpoint)
  42. model.eval()
  43. torch.manual_seed(27)
  44. random.seed(27)
  45. np.random.seed(27)
  46. text_explainer = LimeTextExplainer(class_names=config.classes)
  47. image_explainer = LimeImageExplainer()
  48. make_directory(config.output_path + '/lime/')
  49. make_directory(config.output_path + '/lime/score/')
  50. make_directory(config.output_path + '/lime/logit/')
  51. make_directory(config.output_path + '/lime/text/')
  52. make_directory(config.output_path + '/lime/image/')
  53. def text_predict_proba(text):
  54. scores = []
  55. print(len(text))
  56. for i in text:
  57. row['text'] = i
  58. test_loader = build_loaders(config, row, mode="lime")
  59. score, _ = lime_(config, test_loader, model)
  60. scores.append(score)
  61. return np.array(scores)
  62. def image_predict_proba(image):
  63. scores = []
  64. print(len(image))
  65. for i in image:
  66. test_loader = build_loaders(config, row, mode="lime")
  67. test_loader['image'] = i.reshape((3, 224, 224))
  68. score, _ = lime_(config, test_loader, model)
  69. scores.append(score)
  70. return np.array(scores)
  71. for i, row in test_df.iterrows():
  72. test_loader = build_loaders(config, row, mode="lime")
  73. score, logit = lime_(config, test_loader, model)
  74. np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",")
  75. np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",")
  76. text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5)
  77. text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html')
  78. print('text', i, 'finished')
  79. data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0)
  80. img_exp = image_explainer.explain_instance(data_items['image'].reshape((224, 224, 3)), image_predict_proba,
  81. top_labels=2, hide_color=0, num_samples=1)
  82. temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5,
  83. hide_rest=False)
  84. img_boundry = mark_boundaries(temp / 255.0, mask)
  85. plt.imshow(img_boundry)
  86. plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png')
  87. print('image', i, 'finished')