Browse Source

statistical analysis added

master
Faeze 1 year ago
parent
commit
1b80a9a519
8 changed files with 1148 additions and 135 deletions
  1. 5
    1
      .gitignore
  2. 5
    10
      data/data_loader.py
  3. 1118
    0
      data/twitter/EDA.ipynb
  4. 1
    2
      data_loaders.py
  5. 8
    0
      evaluation.py
  6. 0
    110
      lime_main.py
  7. 1
    1
      model.py
  8. 10
    11
      test_main.py

+ 5
- 1
.gitignore View File

@@ -117,4 +117,8 @@ dmypy.json
.idea/
env/
venv/
results/
results/
*.jpg
*.csv
*.json
*.png

+ 5
- 10
data/data_loader.py View File

@@ -1,7 +1,7 @@
import torch
from torch.utils.data import Dataset
import albumentations as A
from transformers import ViTFeatureExtractor
from transformers import ViTFeatureExtractor, AutoTokenizer
from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer


@@ -16,7 +16,7 @@ def get_transforms(config):

def get_tokenizer(config):
if 'roberta' in config.text_encoder_model:
tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer)
tokenizer = AutoTokenizer.from_pretrained(config.text_tokenizer)
elif 'xlnet' in config.text_encoder_model:
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
else:
@@ -28,14 +28,9 @@ class DatasetLoader(Dataset):
def __init__(self, config, dataframe, mode):
self.config = config
self.mode = mode
if mode == 'lime':
self.image_filenames = [dataframe["image"],]
self.text = [dataframe["text"],]
self.labels = [dataframe["label"],]
else:
self.image_filenames = dataframe["image"].values
self.text = list(dataframe["text"].values)
self.labels = dataframe["label"].values
self.image_filenames = dataframe["image"].values
self.text = list(dataframe["text"].values)
self.labels = dataframe["label"].values

tokenizer = get_tokenizer(config)
self.encoded_text = tokenizer(self.text, padding=True, truncation=True, max_length=config.max_length, return_tensors='pt')

+ 1118
- 0
data/twitter/EDA.ipynb
File diff suppressed because it is too large
View File


+ 1
- 2
data_loaders.py View File

@@ -19,7 +19,6 @@ def make_dfs(config):
else:
test_dataframe = pd.read_csv(config.test_text_path)
test_dataframe.dropna(subset=['text'], inplace=True)
test_dataframe = test_dataframe.sample(frac=1).reset_index(drop=True)
test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x))

if config.validation_text_path is None:
@@ -40,7 +39,7 @@ def build_loaders(config, dataframe, mode):
if mode != 'train':
dataloader = DataLoader(
dataset,
batch_size=config.batch_size // 2 if mode == 'lime' else 1,
batch_size=config.batch_size // 2,
num_workers=config.num_workers // 2,
pin_memory=False,
shuffle=False,

+ 8
- 0
evaluation.py View File

@@ -257,6 +257,14 @@ def save_embedding(x, fname='embedding.tsv'):
embedding_df.to_csv(fname, sep='\t', index=False, header=False)


def save_2D_embedding(x, fname=''):
x = [[i.cpu().numpy() for i in j] for j in x]

for i, batch in enumerate(x):
embedding_df = pd.DataFrame(batch)
embedding_df.to_csv(fname + '/batch ' + str(i) + '.csv', sep=',')


def plot_pca(config, x, y, fname='pca.png'):
x = [i.cpu().numpy() for i in x]
y = [i.cpu().numpy() for i in y]

+ 0
- 110
lime_main.py View File

@@ -1,110 +0,0 @@
import random

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from learner import batch_constructor
from model import FakeNewsModel

from lime.lime_image import LimeImageExplainer
from skimage.segmentation import mark_boundaries
from lime.lime_text import LimeTextExplainer

from utils import make_directory


def lime_(config, test_loader, model):
for i, batch in enumerate(test_loader):
batch = batch_constructor(config, batch)
with torch.no_grad():
output, score = model(batch)

score = score.detach().cpu().numpy()

logit = model.logits.detach().cpu().numpy()

return score, logit


def lime_main(config, trial_number=None):
_, test_df, _ = make_dfs(config, )
test_df = test_df[:1]

if trial_number:
try:
checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
except:
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
else:
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))

try:
parameters = checkpoint['parameters']
config.assign_hyperparameters(parameters)
except:
pass

model = FakeNewsModel(config).to(config.device)
try:
model.load_state_dict(checkpoint['model_state_dict'])
except:
model.load_state_dict(checkpoint)

model.eval()

torch.manual_seed(27)
random.seed(27)
np.random.seed(27)

text_explainer = LimeTextExplainer(class_names=config.classes)
image_explainer = LimeImageExplainer()

make_directory(config.output_path + '/lime/')
make_directory(config.output_path + '/lime/score/')
make_directory(config.output_path + '/lime/logit/')
make_directory(config.output_path + '/lime/text/')
make_directory(config.output_path + '/lime/image/')

def text_predict_proba(text):
scores = []
print(len(text))
for i in text:
row['text'] = i
test_loader = build_loaders(config, row, mode="lime")
score, _ = lime_(config, test_loader, model)
scores.append(score)
return np.array(scores)

def image_predict_proba(image):
scores = []
print(len(image))
for i in image:
test_loader = build_loaders(config, row, mode="lime")
test_loader['image'] = i.reshape((3, 224, 224))
score, _ = lime_(config, test_loader, model)
scores.append(score)
return np.array(scores)

for i, row in test_df.iterrows():
test_loader = build_loaders(config, row, mode="lime")
score, logit = lime_(config, test_loader, model)
np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",")
np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",")

text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5)
text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html')
print('text', i, 'finished')

data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0)
img_exp = image_explainer.explain_instance(data_items['image'].reshape((224, 224, 3)), image_predict_proba,
top_labels=2, hide_color=0, num_samples=1)
temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5,
hide_rest=False)
img_boundry = mark_boundaries(temp / 255.0, mask)
plt.imshow(img_boundry)
plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png')
print('image', i, 'finished')

+ 1
- 1
model.py View File

@@ -19,7 +19,7 @@ class ProjectionHead(nn.Module):
self.gelu = nn.GELU()
self.fc = nn.Linear(config.projection_size, config.projection_size)
self.dropout = nn.Dropout(config.dropout)
self.layer_norm = nn.BatchNorm1d(config.projection_size)
self.layer_norm = nn.LayerNorm(config.projection_size)

def forward(self, x):
projected = self.projection(x)

+ 10
- 11
test_main.py View File

@@ -6,8 +6,7 @@ import torch
from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca, \
save_embedding, save_loss
from evaluation import *
from learner import batch_constructor
from model import FakeNewsModel, calculate_loss

@@ -68,24 +67,24 @@ def test(config, test_loader, trial_number=None):

s = ''
s += report_per_class(targets, predictions) + '\n'
s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/new_fpr_tpr.csv') + '\n'
with open(config.output_path + '/new_results.txt', 'w') as f:
s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n'
with open(config.output_path + '/results.txt', 'w') as f:
f.write(s)

roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")

save_embedding(image_features, fname=str(config.output_path) + '/new_image_features.tsv')
save_embedding(text_features, fname=str(config.output_path) + '/new_text_features.tsv')
save_embedding(multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv')
save_embedding(concat_features, fname=str(config.output_path) + '/new_concat_features.tsv')
save_embedding(similarities, fname=str(config.output_path) + '/similarities.tsv')
save_embedding(image_features, fname=str(config.output_path) + '/image_features.tsv')
save_embedding(text_features, fname=str(config.output_path) + '/text_features.tsv')
save_embedding(multimodal_features, fname=str(config.output_path) + '/multimodal_features_.tsv')
save_embedding(concat_features, fname=str(config.output_path) + '/concat_features.tsv')
save_2D_embedding(similarities, fname=str(config.output_path))

config_parameters = str(config)
with open(config.output_path + '/new_parameters.txt', 'w') as f:
with open(config.output_path + '/parameters.txt', 'w') as f:
f.write(config_parameters)

save_loss(ids, predictions, targets, losses, str(config.output_path) + '/new_text_label.csv')
save_loss(ids, predictions, targets, losses, str(config.output_path) + '/text_label.csv')




Loading…
Cancel
Save