|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- import random
-
- import numpy as np
- import pandas as pd
- import torch
- from matplotlib import pyplot as plt
- from tqdm import tqdm
-
- from data_loaders import make_dfs, build_loaders
- from learner import batch_constructor
- from model import FakeNewsModel
-
- from lime.lime_image import LimeImageExplainer
- from skimage.segmentation import mark_boundaries
- from lime.lime_text import LimeTextExplainer
-
- from utils import make_directory
-
-
def lime_(config, test_loader, model):
    """Run ``model`` over every batch of ``test_loader`` and return the last batch's outputs.

    Each batch is passed through ``batch_constructor(config, batch)`` and then
    through ``model`` with autograd disabled. Only the values from the final
    batch are returned; earlier batches are overwritten.

    Args:
        config: project config forwarded to ``batch_constructor``.
        test_loader: iterable of raw batches (as produced by ``build_loaders``).
        model: model returning ``(output, score)`` from its forward call and
            exposing a ``logits`` attribute afterwards.

    Returns:
        tuple: ``(score, logit)`` as NumPy arrays on the CPU.

    Raises:
        ValueError: if ``test_loader`` yields no batches (previously this
            surfaced as an ``UnboundLocalError``).
    """
    seen_batch = False
    score = logit = None
    for batch in test_loader:
        seen_batch = True
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            _, score = model(batch)
        score = score.detach().cpu().numpy()
        # model.logits is read as an attribute after the forward call;
        # presumably the forward pass stores it there — confirm in model.py.
        logit = model.logits.detach().cpu().numpy()

    if not seen_batch:
        raise ValueError("test_loader yielded no batches; nothing to explain")
    return score, logit
-
-
def _load_checkpoint(config, trial_number=None):
    """Load a checkpoint from ``config.output_path`` onto ``config.device``.

    If ``trial_number`` is given (0 included), ``checkpoint_<trial_number>.pt``
    is tried first. If that file does not exist, ``checkpoint.pt`` is loaded
    instead.
    """
    output_path = str(config.output_path)
    map_location = torch.device(config.device)
    if trial_number is not None:
        try:
            return torch.load(output_path + '/checkpoint_' + str(trial_number) + '.pt',
                              map_location=map_location)
        except FileNotFoundError:
            pass  # no per-trial checkpoint: fall back to the default one below
    return torch.load(output_path + '/checkpoint.pt', map_location=map_location)


def _build_model(config, checkpoint):
    """Build ``FakeNewsModel`` on ``config.device`` from ``checkpoint`` and switch it to eval mode.

    ``checkpoint`` may be either a dict holding ``'model_state_dict'`` (and
    optionally ``'parameters'``), or a bare state dict.
    """
    is_mapping = isinstance(checkpoint, dict)
    if is_mapping and 'parameters' in checkpoint:
        # Hyperparameters are applied before the model is constructed,
        # presumably because they affect its architecture.
        config.assign_hyperparameters(checkpoint['parameters'])

    model = FakeNewsModel(config).to(config.device)
    if is_mapping and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint  # a state dict saved directly
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _to_lime_image(image):
    """Return ``image`` as a (224, 224, 3) channel-last NumPy array for LIME."""
    if torch.is_tensor(image):
        image = image.detach().cpu().numpy()
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3:
        # Channel-first (3, H, W) -> channel-last (H, W, 3). A plain reshape
        # would reinterpret memory and scramble the pixels.
        image = np.transpose(image, (1, 2, 0))
    return image.reshape((224, 224, 3))


def _make_text_predict_proba(config, model, row):
    """Return a LIME ``classifier_fn`` that scores perturbed versions of ``row['text']``."""
    def text_predict_proba(texts):
        scores = []
        print(len(texts))
        for text in texts:
            # Perturb a copy so ``row`` stays intact for the image explanation.
            perturbed_row = row.copy()
            perturbed_row['text'] = text
            loader = build_loaders(config, perturbed_row, mode="lime")
            score, _ = lime_(config, loader, model)
            scores.append(score)
        # vstack gives (num_samples, num_classes) whether each score is (C,) or (1, C).
        return np.vstack(scores)

    return text_predict_proba


def _make_image_predict_proba(config, model, row):
    """Return a LIME ``classifier_fn`` that scores perturbed images for ``row``."""
    def image_predict_proba(images):
        scores = []
        print(len(images))
        for image in images:
            loader = build_loaders(config, row, mode="lime")
            # LIME hands back channel-last images; the model side uses channel-first.
            # NOTE(review): this assigns onto whatever build_loaders returns; it only
            # takes effect if that object supports item assignment and lime_ reads
            # the 'image' entry back — confirm against data_loaders.build_loaders.
            loader['image'] = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
            score, _ = lime_(config, loader, model)
            scores.append(score)
        return np.vstack(scores)

    return image_predict_proba


def _save_image_explanation(img_exp, path):
    """Render the top label's superpixel boundaries for ``img_exp`` and save them to ``path``."""
    temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5,
                                            hide_rest=False)
    img_boundary = mark_boundaries(temp / 255.0, mask)
    # A dedicated figure per image, closed after saving, so successive rows
    # neither draw on top of each other nor accumulate open figures.
    fig, ax = plt.subplots()
    ax.imshow(img_boundary)
    fig.savefig(path)
    plt.close(fig)


def lime_main(config, trial_number=None):
    """Generate LIME text and image explanations and save them under ``<config.output_path>/lime/``.

    Only the first row of the test split returned by ``make_dfs`` is explained.

    Args:
        config: project config; this function reads ``output_path``, ``device``,
            ``classes`` and ``DatasetLoader``, and calls ``assign_hyperparameters``.
        trial_number: optional trial id. When given (0 included),
            ``checkpoint_<trial_number>.pt`` is preferred over ``checkpoint.pt``.

    For each explained row ``i`` this writes:
        ``lime/score/i.csv`` and ``lime/logit/i.csv`` (model outputs from ``lime_``),
        ``lime/text/i.html`` (text explanation) and
        ``lime/image/i.png`` (image with superpixel boundaries).
    """
    _, test_df, _ = make_dfs(config)
    test_df = test_df[:1]  # only the first test row is explained

    checkpoint = _load_checkpoint(config, trial_number)
    model = _build_model(config, checkpoint)

    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    text_explainer = LimeTextExplainer(class_names=config.classes)
    image_explainer = LimeImageExplainer()

    lime_dir = str(config.output_path) + '/lime/'
    for sub_dir in ('', 'score/', 'logit/', 'text/', 'image/'):
        make_directory(lime_dir + sub_dir)

    for i, row in test_df.iterrows():
        test_loader = build_loaders(config, row, mode="lime")
        score, logit = lime_(config, test_loader, model)
        np.savetxt(lime_dir + 'score/' + str(i) + '.csv', score, delimiter=",")
        np.savetxt(lime_dir + 'logit/' + str(i) + '.csv', logit, delimiter=",")

        text_exp = text_explainer.explain_instance(row['text'],
                                                   _make_text_predict_proba(config, model, row),
                                                   num_features=5)
        text_exp.save_to_file(lime_dir + 'text/' + str(i) + '.html')
        print('text', i, 'finished')

        data_items = config.DatasetLoader(config, dataframe=row, mode='lime')[0]
        # NOTE(review): num_samples=1 is far below LIME's default of 1000, so the
        # surrogate model is barely fitted — presumably a debugging setting; confirm.
        img_exp = image_explainer.explain_instance(_to_lime_image(data_items['image']),
                                                   _make_image_predict_proba(config, model, row),
                                                   top_labels=2, hide_color=0, num_samples=1)
        _save_image_explanation(img_exp, lime_dir + 'image/' + str(i) + '.png')
        print('image', i, 'finished')
|