| @@ -117,4 +117,8 @@ dmypy.json | |||
| .idea/ | |||
| env/ | |||
| venv/ | |||
| results/ | |||
| results/ | |||
| *.jpg | |||
| *.csv | |||
| *.json | |||
| *.png | |||
| @@ -1,7 +1,7 @@ | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| import albumentations as A | |||
| from transformers import ViTFeatureExtractor | |||
| from transformers import ViTFeatureExtractor, AutoTokenizer | |||
| from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer | |||
| @@ -16,7 +16,7 @@ def get_transforms(config): | |||
| def get_tokenizer(config): | |||
| if 'roberta' in config.text_encoder_model: | |||
| tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer) | |||
| tokenizer = AutoTokenizer.from_pretrained(config.text_tokenizer) | |||
| elif 'xlnet' in config.text_encoder_model: | |||
| tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||
| else: | |||
| @@ -28,14 +28,9 @@ class DatasetLoader(Dataset): | |||
| def __init__(self, config, dataframe, mode): | |||
| self.config = config | |||
| self.mode = mode | |||
| if mode == 'lime': | |||
| self.image_filenames = [dataframe["image"],] | |||
| self.text = [dataframe["text"],] | |||
| self.labels = [dataframe["label"],] | |||
| else: | |||
| self.image_filenames = dataframe["image"].values | |||
| self.text = list(dataframe["text"].values) | |||
| self.labels = dataframe["label"].values | |||
| self.image_filenames = dataframe["image"].values | |||
| self.text = list(dataframe["text"].values) | |||
| self.labels = dataframe["label"].values | |||
| tokenizer = get_tokenizer(config) | |||
| self.encoded_text = tokenizer(self.text, padding=True, truncation=True, max_length=config.max_length, return_tensors='pt') | |||
| @@ -19,7 +19,6 @@ def make_dfs(config): | |||
| else: | |||
| test_dataframe = pd.read_csv(config.test_text_path) | |||
| test_dataframe.dropna(subset=['text'], inplace=True) | |||
| test_dataframe = test_dataframe.sample(frac=1).reset_index(drop=True) | |||
| test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
| if config.validation_text_path is None: | |||
| @@ -40,7 +39,7 @@ def build_loaders(config, dataframe, mode): | |||
| if mode != 'train': | |||
| dataloader = DataLoader( | |||
| dataset, | |||
| batch_size=config.batch_size // 2 if mode == 'lime' else 1, | |||
| batch_size=config.batch_size // 2, | |||
| num_workers=config.num_workers // 2, | |||
| pin_memory=False, | |||
| shuffle=False, | |||
| @@ -257,6 +257,14 @@ def save_embedding(x, fname='embedding.tsv'): | |||
| embedding_df.to_csv(fname, sep='\t', index=False, header=False) | |||
| def save_2D_embedding(x, fname=''): | |||
| x = [[i.cpu().numpy() for i in j] for j in x] | |||
| for i, batch in enumerate(x): | |||
| embedding_df = pd.DataFrame(batch) | |||
| embedding_df.to_csv(fname + '/batch ' + str(i) + '.csv', sep=',') | |||
| def plot_pca(config, x, y, fname='pca.png'): | |||
| x = [i.cpu().numpy() for i in x] | |||
| y = [i.cpu().numpy() for i in y] | |||
| @@ -1,110 +0,0 @@ | |||
| import random | |||
| import numpy as np | |||
| import pandas as pd | |||
| import torch | |||
| from matplotlib import pyplot as plt | |||
| from tqdm import tqdm | |||
| from data_loaders import make_dfs, build_loaders | |||
| from learner import batch_constructor | |||
| from model import FakeNewsModel | |||
| from lime.lime_image import LimeImageExplainer | |||
| from skimage.segmentation import mark_boundaries | |||
| from lime.lime_text import LimeTextExplainer | |||
| from utils import make_directory | |||
| def lime_(config, test_loader, model): | |||
| for i, batch in enumerate(test_loader): | |||
| batch = batch_constructor(config, batch) | |||
| with torch.no_grad(): | |||
| output, score = model(batch) | |||
| score = score.detach().cpu().numpy() | |||
| logit = model.logits.detach().cpu().numpy() | |||
| return score, logit | |||
| def lime_main(config, trial_number=None): | |||
| _, test_df, _ = make_dfs(config, ) | |||
| test_df = test_df[:1] | |||
| if trial_number: | |||
| try: | |||
| checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt') | |||
| except: | |||
| checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt') | |||
| else: | |||
| checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device)) | |||
| try: | |||
| parameters = checkpoint['parameters'] | |||
| config.assign_hyperparameters(parameters) | |||
| except: | |||
| pass | |||
| model = FakeNewsModel(config).to(config.device) | |||
| try: | |||
| model.load_state_dict(checkpoint['model_state_dict']) | |||
| except: | |||
| model.load_state_dict(checkpoint) | |||
| model.eval() | |||
| torch.manual_seed(27) | |||
| random.seed(27) | |||
| np.random.seed(27) | |||
| text_explainer = LimeTextExplainer(class_names=config.classes) | |||
| image_explainer = LimeImageExplainer() | |||
| make_directory(config.output_path + '/lime/') | |||
| make_directory(config.output_path + '/lime/score/') | |||
| make_directory(config.output_path + '/lime/logit/') | |||
| make_directory(config.output_path + '/lime/text/') | |||
| make_directory(config.output_path + '/lime/image/') | |||
| def text_predict_proba(text): | |||
| scores = [] | |||
| print(len(text)) | |||
| for i in text: | |||
| row['text'] = i | |||
| test_loader = build_loaders(config, row, mode="lime") | |||
| score, _ = lime_(config, test_loader, model) | |||
| scores.append(score) | |||
| return np.array(scores) | |||
| def image_predict_proba(image): | |||
| scores = [] | |||
| print(len(image)) | |||
| for i in image: | |||
| test_loader = build_loaders(config, row, mode="lime") | |||
| test_loader['image'] = i.reshape((3, 224, 224)) | |||
| score, _ = lime_(config, test_loader, model) | |||
| scores.append(score) | |||
| return np.array(scores) | |||
| for i, row in test_df.iterrows(): | |||
| test_loader = build_loaders(config, row, mode="lime") | |||
| score, logit = lime_(config, test_loader, model) | |||
| np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",") | |||
| np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",") | |||
| text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5) | |||
| text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html') | |||
| print('text', i, 'finished') | |||
| data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0) | |||
| img_exp = image_explainer.explain_instance(data_items['image'].reshape((224, 224, 3)), image_predict_proba, | |||
| top_labels=2, hide_color=0, num_samples=1) | |||
| temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5, | |||
| hide_rest=False) | |||
| img_boundry = mark_boundaries(temp / 255.0, mask) | |||
| plt.imshow(img_boundry) | |||
| plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png') | |||
| print('image', i, 'finished') | |||
| @@ -19,7 +19,7 @@ class ProjectionHead(nn.Module): | |||
| self.gelu = nn.GELU() | |||
| self.fc = nn.Linear(config.projection_size, config.projection_size) | |||
| self.dropout = nn.Dropout(config.dropout) | |||
| self.layer_norm = nn.BatchNorm1d(config.projection_size) | |||
| self.layer_norm = nn.LayerNorm(config.projection_size) | |||
| def forward(self, x): | |||
| projected = self.projection(x) | |||
| @@ -6,8 +6,7 @@ import torch | |||
| from tqdm import tqdm | |||
| from data_loaders import make_dfs, build_loaders | |||
| from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca, \ | |||
| save_embedding, save_loss | |||
| from evaluation import * | |||
| from learner import batch_constructor | |||
| from model import FakeNewsModel, calculate_loss | |||
| @@ -68,24 +67,24 @@ def test(config, test_loader, trial_number=None): | |||
| s = '' | |||
| s += report_per_class(targets, predictions) + '\n' | |||
| s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/new_fpr_tpr.csv') + '\n' | |||
| with open(config.output_path + '/new_results.txt', 'w') as f: | |||
| s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n' | |||
| with open(config.output_path + '/results.txt', 'w') as f: | |||
| f.write(s) | |||
| roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png") | |||
| precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png") | |||
| save_embedding(image_features, fname=str(config.output_path) + '/new_image_features.tsv') | |||
| save_embedding(text_features, fname=str(config.output_path) + '/new_text_features.tsv') | |||
| save_embedding(multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv') | |||
| save_embedding(concat_features, fname=str(config.output_path) + '/new_concat_features.tsv') | |||
| save_embedding(similarities, fname=str(config.output_path) + '/similarities.tsv') | |||
| save_embedding(image_features, fname=str(config.output_path) + '/image_features.tsv') | |||
| save_embedding(text_features, fname=str(config.output_path) + '/text_features.tsv') | |||
| save_embedding(multimodal_features, fname=str(config.output_path) + '/multimodal_features_.tsv') | |||
| save_embedding(concat_features, fname=str(config.output_path) + '/concat_features.tsv') | |||
| save_2D_embedding(similarities, fname=str(config.output_path)) | |||
| config_parameters = str(config) | |||
| with open(config.output_path + '/new_parameters.txt', 'w') as f: | |||
| with open(config.output_path + '/parameters.txt', 'w') as f: | |||
| f.write(config_parameters) | |||
| save_loss(ids, predictions, targets, losses, str(config.output_path) + '/new_text_label.csv') | |||
| save_loss(ids, predictions, targets, losses, str(config.output_path) + '/text_label.csv') | |||