@@ -117,4 +117,8 @@ dmypy.json
 .idea/
 env/
 venv/
-results/
+results/
+*.jpg
+*.csv
+*.json
+*.png
@@ -1,7 +1,7 @@
 import torch
 from torch.utils.data import Dataset
 import albumentations as A
-from transformers import ViTFeatureExtractor
+from transformers import ViTFeatureExtractor, AutoTokenizer
 from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer
@@ -16,7 +16,7 @@ def get_transforms(config):
 def get_tokenizer(config):
     if 'roberta' in config.text_encoder_model:
-        tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer)
+        tokenizer = AutoTokenizer.from_pretrained(config.text_tokenizer)
     elif 'xlnet' in config.text_encoder_model:
         tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
     else:
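
Note on the tokenizer change above: `AutoTokenizer` resolves the concrete tokenizer class from the checkpoint itself, so the RoBERTa branch no longer has to name `RobertaTokenizer` explicitly. A minimal sketch, assuming a hypothetical `"roberta-base"` value for `config.text_tokenizer`:

```python
from transformers import AutoTokenizer

# AutoTokenizer reads the checkpoint's config and returns the matching
# tokenizer class, so callers do not need to hard-code RobertaTokenizer.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # assumed checkpoint name
print(type(tokenizer).__name__)  # typically a (fast) RoBERTa tokenizer class
```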
@@ -28,14 +28,9 @@ class DatasetLoader(Dataset):
     def __init__(self, config, dataframe, mode):
         self.config = config
         self.mode = mode
-        if mode == 'lime':
-            self.image_filenames = [dataframe["image"],]
-            self.text = [dataframe["text"],]
-            self.labels = [dataframe["label"],]
-        else:
-            self.image_filenames = dataframe["image"].values
-            self.text = list(dataframe["text"].values)
-            self.labels = dataframe["label"].values
+        self.image_filenames = dataframe["image"].values
+        self.text = list(dataframe["text"].values)
+        self.labels = dataframe["label"].values
         tokenizer = get_tokenizer(config)
         self.encoded_text = tokenizer(self.text, padding=True, truncation=True, max_length=config.max_length, return_tensors='pt')
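
With the `'lime'` branch removed above, `DatasetLoader.__init__` always encodes the full text column in a single batched tokenizer call. A small sketch of that call, with placeholder texts and a stand-in value for `config.max_length`:

```python
# Assumes `tokenizer` was obtained via get_tokenizer(config) as shown earlier.
texts = ["a short claim", "a much longer claim that will be truncated if needed"]
encoded = tokenizer(
    texts,
    padding=True,         # pad every item to the longest one in the list
    truncation=True,      # drop tokens beyond max_length
    max_length=32,        # stand-in for config.max_length
    return_tensors="pt",  # stacked PyTorch tensors, indexable in __getitem__
)
print(encoded["input_ids"].shape)       # (2, padded_sequence_length)
print(encoded["attention_mask"].shape)  # same shape as input_ids
```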
@@ -19,7 +19,6 @@ def make_dfs(config):
     else:
         test_dataframe = pd.read_csv(config.test_text_path)
         test_dataframe.dropna(subset=['text'], inplace=True)
-        test_dataframe = test_dataframe.sample(frac=1).reset_index(drop=True)
         test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x))
     if config.validation_text_path is None:
@@ -40,7 +39,7 @@ def build_loaders(config, dataframe, mode):
     if mode != 'train':
         dataloader = DataLoader(
             dataset,
-            batch_size=config.batch_size // 2 if mode == 'lime' else 1,
+            batch_size=config.batch_size // 2,
             num_workers=config.num_workers // 2,
             pin_memory=False,
             shuffle=False,
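
The non-train loader now always uses half the training batch size rather than a batch size of 1 (previously only the removed `'lime'` mode got `batch_size // 2`). A minimal sketch of the resulting loader settings, using a dummy dataset in place of `DatasetLoader` and assumed config values:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size, num_workers = 32, 4             # stand-ins for config values
dataset = TensorDataset(torch.arange(100))  # placeholder for DatasetLoader

# Mirrors the non-train branch after this change: larger evaluation batches,
# no shuffling, and half as many workers as the training loader.
eval_loader = DataLoader(
    dataset,
    batch_size=batch_size // 2,
    num_workers=num_workers // 2,
    pin_memory=False,
    shuffle=False,
)
print(len(eval_loader))  # 100 samples / 16 per batch -> 7 batches
```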
@@ -257,6 +257,14 @@ def save_embedding(x, fname='embedding.tsv'):
     embedding_df.to_csv(fname, sep='\t', index=False, header=False)
+def save_2D_embedding(x, fname=''):
+    x = [[i.cpu().numpy() for i in j] for j in x]
+    for i, batch in enumerate(x):
+        embedding_df = pd.DataFrame(batch)
+        embedding_df.to_csv(fname + '/batch ' + str(i) + '.csv', sep=',')
 def plot_pca(config, x, y, fname='pca.png'):
     x = [i.cpu().numpy() for i in x]
     y = [i.cpu().numpy() for i in y]
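
The new `save_2D_embedding` above writes one CSV per batch from a list of batches of tensors (it is used below for the per-batch similarity matrices in the test script). A usage sketch with random tensors, assuming the output directory already exists:

```python
import os
import torch
from evaluation import save_2D_embedding

out_dir = "results"  # assumed output directory
os.makedirs(out_dir, exist_ok=True)

# Two "batches", each a list of per-sample similarity rows.
similarities = [
    [torch.rand(8) for _ in range(4)],
    [torch.rand(8) for _ in range(4)],
]
save_2D_embedding(similarities, fname=out_dir)
# -> results/batch 0.csv and results/batch 1.csv, one row per tensor
```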
@@ -1,110 +0,0 @@
-import random
-import numpy as np
-import pandas as pd
-import torch
-from matplotlib import pyplot as plt
-from tqdm import tqdm
-from data_loaders import make_dfs, build_loaders
-from learner import batch_constructor
-from model import FakeNewsModel
-from lime.lime_image import LimeImageExplainer
-from skimage.segmentation import mark_boundaries
-from lime.lime_text import LimeTextExplainer
-from utils import make_directory
-def lime_(config, test_loader, model):
-    for i, batch in enumerate(test_loader):
-        batch = batch_constructor(config, batch)
-        with torch.no_grad():
-            output, score = model(batch)
-            score = score.detach().cpu().numpy()
-            logit = model.logits.detach().cpu().numpy()
-    return score, logit
-def lime_main(config, trial_number=None):
-    _, test_df, _ = make_dfs(config, )
-    test_df = test_df[:1]
-    if trial_number:
-        try:
-            checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
-        except:
-            checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
-    else:
-        checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))
-    try:
-        parameters = checkpoint['parameters']
-        config.assign_hyperparameters(parameters)
-    except:
-        pass
-    model = FakeNewsModel(config).to(config.device)
-    try:
-        model.load_state_dict(checkpoint['model_state_dict'])
-    except:
-        model.load_state_dict(checkpoint)
-    model.eval()
-    torch.manual_seed(27)
-    random.seed(27)
-    np.random.seed(27)
-    text_explainer = LimeTextExplainer(class_names=config.classes)
-    image_explainer = LimeImageExplainer()
-    make_directory(config.output_path + '/lime/')
-    make_directory(config.output_path + '/lime/score/')
-    make_directory(config.output_path + '/lime/logit/')
-    make_directory(config.output_path + '/lime/text/')
-    make_directory(config.output_path + '/lime/image/')
-    def text_predict_proba(text):
-        scores = []
-        print(len(text))
-        for i in text:
-            row['text'] = i
-            test_loader = build_loaders(config, row, mode="lime")
-            score, _ = lime_(config, test_loader, model)
-            scores.append(score)
-        return np.array(scores)
-    def image_predict_proba(image):
-        scores = []
-        print(len(image))
-        for i in image:
-            test_loader = build_loaders(config, row, mode="lime")
-            test_loader['image'] = i.reshape((3, 224, 224))
-            score, _ = lime_(config, test_loader, model)
-            scores.append(score)
-        return np.array(scores)
-    for i, row in test_df.iterrows():
-        test_loader = build_loaders(config, row, mode="lime")
-        score, logit = lime_(config, test_loader, model)
-        np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",")
-        np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",")
-        text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5)
-        text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html')
-        print('text', i, 'finished')
-        data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0)
-        img_exp = image_explainer.explain_instance(data_items['image'].reshape((224, 224, 3)), image_predict_proba,
-                                                   top_labels=2, hide_color=0, num_samples=1)
-        temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5,
-                                                hide_rest=False)
-        img_boundry = mark_boundaries(temp / 255.0, mask)
-        plt.imshow(img_boundry)
-        plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png')
-        print('image', i, 'finished')
@@ -19,7 +19,7 @@ class ProjectionHead(nn.Module):
         self.gelu = nn.GELU()
         self.fc = nn.Linear(config.projection_size, config.projection_size)
         self.dropout = nn.Dropout(config.dropout)
-        self.layer_norm = nn.BatchNorm1d(config.projection_size)
+        self.layer_norm = nn.LayerNorm(config.projection_size)
     def forward(self, x):
         projected = self.projection(x)
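
Swapping `nn.BatchNorm1d` for `nn.LayerNorm` in the projection head normalizes each projected vector over its feature dimension instead of over the batch, so the head behaves the same in train and eval mode and still works for a batch of one. A minimal sketch of the difference, assuming a hypothetical `projection_size` of 256:

```python
import torch
from torch import nn

projection_size = 256                # stand-in for config.projection_size
x = torch.randn(4, projection_size)  # (batch, features)

layer_norm = nn.LayerNorm(projection_size)    # statistics per sample, over features
batch_norm = nn.BatchNorm1d(projection_size)  # statistics per feature, over the batch

print(layer_norm(x).shape, batch_norm(x).shape)  # both torch.Size([4, 256])
# LayerNorm also accepts a single sample; BatchNorm1d in training mode raises an
# error for batch size 1 because it cannot estimate per-feature batch statistics.
```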
@@ -6,8 +6,7 @@ import torch
 from tqdm import tqdm
 from data_loaders import make_dfs, build_loaders
-from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca, \
-    save_embedding, save_loss
+from evaluation import *
 from learner import batch_constructor
 from model import FakeNewsModel, calculate_loss
@@ -68,24 +67,24 @@ def test(config, test_loader, trial_number=None):
     s = ''
     s += report_per_class(targets, predictions) + '\n'
-    s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/new_fpr_tpr.csv') + '\n'
-    with open(config.output_path + '/new_results.txt', 'w') as f:
+    s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n'
+    with open(config.output_path + '/results.txt', 'w') as f:
         f.write(s)
     roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
     precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")
-    save_embedding(image_features, fname=str(config.output_path) + '/new_image_features.tsv')
-    save_embedding(text_features, fname=str(config.output_path) + '/new_text_features.tsv')
-    save_embedding(multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv')
-    save_embedding(concat_features, fname=str(config.output_path) + '/new_concat_features.tsv')
-    save_embedding(similarities, fname=str(config.output_path) + '/similarities.tsv')
+    save_embedding(image_features, fname=str(config.output_path) + '/image_features.tsv')
+    save_embedding(text_features, fname=str(config.output_path) + '/text_features.tsv')
+    save_embedding(multimodal_features, fname=str(config.output_path) + '/multimodal_features_.tsv')
+    save_embedding(concat_features, fname=str(config.output_path) + '/concat_features.tsv')
+    save_2D_embedding(similarities, fname=str(config.output_path))
     config_parameters = str(config)
-    with open(config.output_path + '/new_parameters.txt', 'w') as f:
+    with open(config.output_path + '/parameters.txt', 'w') as f:
         f.write(config_parameters)
-    save_loss(ids, predictions, targets, losses, str(config.output_path) + '/new_text_label.csv')
+    save_loss(ids, predictions, targets, losses, str(config.output_path) + '/text_label.csv')