# Pyre type checker
.pyre/
.idea/
env/
venv/
# FakeNewsRevealer
Official implementation of the "Fake News Revealer (FNR): A Similarity and Transformer-Based Approach to Detect
Multi-Modal Fake News in Social Media" paper [(ArXiv Link)](https://arxiv.org/pdf/2112.01131.pdf).
## Requirements
FNR is built in Python 3.6 using PyTorch 1.8. Please use the following command to install the requirements:
```
pip install -r requirements.txt
```
## How to Run
First, set the data paths and other settings in the dataset's config file in the `data` directory (a minimal example
is sketched below), then follow the train and test commands.
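The per-dataset configuration is a Python class (for example `data/twitter/config.py`); the excerpt below is a minimal sketch of the fields you would typically point at your local copy of the data (the paths shown are placeholders, not the real dataset location):
```python
# data/twitter/config.py (excerpt, illustrative paths)
from data.config import Config
from data.twitter.data_loader import TwitterDatasetLoader

class TwitterConfig(Config):
    name = 'twitter'
    DatasetLoader = TwitterDatasetLoader
    data_path = '/path/to/twitter/'  # dataset root on your machine
    train_image_path = data_path + 'images_train/'
    validation_image_path = data_path + 'images_validation/'
    test_image_path = data_path + 'images_test/'
    train_text_path = data_path + 'twitter_train_translated.csv'
    validation_text_path = data_path + 'twitter_validation_translated.csv'
    test_text_path = data_path + 'twitter_test_translated.csv'
    batch_size = 128
    epochs = 100
```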
### train
To run with Optuna for hyperparameter tuning, use this command (`DATA NAME` is `twitter` or `weibo`):
```
python main.py --data "DATA NAME" --use_optuna "NUMBER OF OPTUNA TRIALS" --batch "BATCH SIZE" --epoch "EPOCHS NUMBER"
```
To run without hyperparameter tuning, adjust the parameters in the config file and then use the command below:
```
python main.py --data "DATA NAME" --batch "BATCH SIZE" --epoch "EPOCHS NUMBER"
```
### test
For testing, first make sure the requested checkpoint file (saved during training) exists, then run:
```
python main.py --data "DATA NAME" --just_test "REQUESTED TRIAL NUMBER"
```
## BibTeX
Cite our paper using the following BibTeX entry:
```
@misc{ghorbanpour2021fnr,
    title={FNR: A Similarity and Transformer-Based Approach to Detect Multi-Modal Fake News in Social Media},
    author={Faeze Ghorbanpour and Maryam Ramezani and Mohammad A. Fazli and Hamid R. Rabiee},
    year={2021},
    eprint={2112.01131},
    archivePrefix={arXiv},
}
```
import torch | |||||
import ast | |||||
class Config: | |||||
name = '' | |||||
DatasetLoader = None | |||||
debug = False | |||||
data_path = '' | |||||
output_path = '' | |||||
train_image_path = '' | |||||
validation_image_path = '' | |||||
test_image_path = '' | |||||
train_text_path = '' | |||||
validation_text_path = '' | |||||
test_text_path = '' | |||||
train_text_embedding_path = '' | |||||
validation_text_embedding_path = '' | |||||
train_image_embedding_path = '' | |||||
validation_image_embedding_path = '' | |||||
batch_size = 256 | |||||
epochs = 50 | |||||
num_workers = 2 | |||||
head_lr = 1e-03 | |||||
image_encoder_lr = 1e-04 | |||||
text_encoder_lr = 1e-04 | |||||
attention_lr = 1e-3 | |||||
classification_lr = 1e-03 | |||||
max_grad_norm = 5.0 | |||||
head_weight_decay = 0.001 | |||||
attention_weight_decay = 0.001 | |||||
classification_weight_decay = 0.001 | |||||
patience = 30 | |||||
delta = 0.0000001 | |||||
factor = 0.8 | |||||
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||||
image_embedding = 768 | |||||
num_img_region = 64 # 16 #TODO | |||||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||||
text_embedding = 768 | |||||
max_length = 32 | |||||
pretrained = True # for both image encoder and text encoder | |||||
trainable = False # for both image encoder and text encoder | |||||
temperature = 1.0 | |||||
# image size | |||||
size = 224 | |||||
num_projection_layers = 1 | |||||
projection_size = 256 | |||||
dropout = 0.3 | |||||
hidden_size = 256 | |||||
num_region = 64 # 16 #TODO | |||||
region_size = projection_size // num_region | |||||
classes = ['real', 'fake'] | |||||
class_num = 2 | |||||
loss_weight = 1 | |||||
class_weights = [1, 1] | |||||
writer = None | |||||
has_unlabeled_data = False | |||||
step = 0 | |||||
T1 = 10 | |||||
T2 = 150 | |||||
af = 3 | |||||
wanted_accuracy = 0.76 | |||||
def optuna(self, trial): | |||||
self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||||
self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||||
self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||||
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||||
self.attention_lr = trial.suggest_loguniform('attention_lr', 1e-5, 1e-1) | |||||
self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||||
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||||
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||||
self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||||
self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||||
self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||||
def __str__(self): | |||||
s = '' | |||||
members = [attr for attr in dir(self) if not callable(getattr(self, attr)) and not attr.startswith("__")] | |||||
for member in members: | |||||
s += member + '\t' + str(getattr(self, member)) + '\n' | |||||
return s | |||||
def assign_hyperparameters(self, s):
    # restore attributes from the tab-separated string produced by __str__
    for line in s.split('\n'):
        parts = line.split('\t')
        try:
            attr = getattr(self, parts[0])
            if type(attr) not in [list, set, dict]:
                setattr(self, parts[0], type(attr)(parts[1]))
            else:
                setattr(self, parts[0], ast.literal_eval(parts[1]))
        except Exception:
            print(parts[0], 'does not exist or could not be parsed')
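A small sketch of how this string round-trip is used: `str(config)` is what `supervised_train` stores in the checkpoint under `parameters`, and `assign_hyperparameters` restores it at test time; attributes whose values cannot be parsed back are skipped with a warning:

```python
from data.config import Config

config = Config()
snapshot = str(config)            # one "name\tvalue" line per attribute
config.batch_size = 64            # change something...
config.assign_hyperparameters(snapshot)
print(config.batch_size)          # ...and it is restored to 256
```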
import torch
from torch.utils.data import Dataset


class DatasetLoader(Dataset):
    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        self.encoded_text = tokenizer(
            list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt'
        )
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

    def set_text(self, idx):
        item = {
            key: values[idx].clone().detach()
            for key, values in self.encoded_text.items()
        }
        item['text'] = self.text[idx]
        item['label'] = self.labels[idx]
        item['id'] = idx
        return item

    def set_image(self, image, item):
        if 'resnet' in self.config.image_model_name:
            image = self.transforms(image=image)['image']
            item['image'] = torch.as_tensor(image).reshape((3, 224, 224))
        else:
            image = self.transforms(images=image, return_tensors='pt')
            image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
            item['image'] = image.reshape((3, 224, 224))
import torch | |||||
from data.config import Config | |||||
from data.twitter.data_loader import TwitterDatasetLoader | |||||
class TwitterConfig(Config): | |||||
name = 'twitter' | |||||
DatasetLoader = TwitterDatasetLoader | |||||
data_path = '/twitter/' | |||||
output_path = '' | |||||
train_image_path = data_path + 'images_train/' | |||||
validation_image_path = data_path + 'images_validation/' | |||||
test_image_path = data_path + 'images_test/' | |||||
train_text_path = data_path + 'twitter_train_translated.csv' | |||||
validation_text_path = data_path + 'twitter_validation_translated.csv' | |||||
test_text_path = data_path + 'twitter_test_translated.csv' | |||||
batch_size = 128 | |||||
epochs = 100 | |||||
num_workers = 2 | |||||
head_lr = 1e-03 | |||||
image_encoder_lr = 1e-04 | |||||
text_encoder_lr = 1e-04 | |||||
attention_lr = 1e-3 | |||||
classification_lr = 1e-03 | |||||
head_weight_decay = 0.001 | |||||
attention_weight_decay = 0.001 | |||||
classification_weight_decay = 0.001 | |||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||||
image_model_name = 'vit-base-patch16-224' | |||||
image_embedding = 768 | |||||
text_encoder_model = "bert-base-uncased" | |||||
text_tokenizer = "bert-base-uncased" | |||||
text_embedding = 768 | |||||
max_length = 32 | |||||
pretrained = True | |||||
trainable = False | |||||
temperature = 1.0 | |||||
classes = ['real', 'fake'] | |||||
class_weights = [1, 1] | |||||
wanted_accuracy = 0.76 | |||||
def optuna(self, trial): | |||||
self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||||
self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||||
self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||||
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||||
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||||
# self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||||
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||||
# self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||||
# self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||||
# self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) |
import cv2 | |||||
import pickle | |||||
from PIL import Image | |||||
from numpy import asarray | |||||
from data.data_loader import DatasetLoader | |||||
class TwitterDatasetLoader(DatasetLoader): | |||||
def __getitem__(self, idx): | |||||
item = self.set_text(idx) | |||||
try: | |||||
if self.mode == 'train': | |||||
image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}") | |||||
elif self.mode == 'validation': | |||||
image = cv2.imread(f"{self.config.validation_image_path}/{self.image_filenames[idx]}") | |||||
else: | |||||
image = cv2.imread(f"{self.config.test_image_path}/{self.image_filenames[idx]}") | |||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |||||
except: | |||||
image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||||
image = asarray(image) | |||||
self.set_image(image, item) | |||||
return item | |||||
def __len__(self): | |||||
return len(self.text) | |||||
class TwitterEmbeddingDatasetLoader(DatasetLoader): | |||||
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||||
super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||||
self.config = config | |||||
self.image_filenames = image_filenames | |||||
self.text = list(texts) | |||||
self.encoded_text = tokenizer | |||||
self.labels = labels | |||||
self.transforms = transforms | |||||
self.mode = mode | |||||
if mode == 'train': | |||||
with open(config.train_image_embedding_path, 'rb') as handle: | |||||
self.image_embedding = pickle.load(handle) | |||||
with open(config.train_text_embedding_path, 'rb') as handle: | |||||
self.text_embedding = pickle.load(handle) | |||||
else: | |||||
with open(config.validation_image_embedding_path, 'rb') as handle: | |||||
self.image_embedding = pickle.load(handle) | |||||
with open(config.validation_text_embedding_path, 'rb') as handle: | |||||
self.text_embedding = pickle.load(handle) | |||||
def __getitem__(self, idx): | |||||
item = dict() | |||||
item['id'] = idx | |||||
item['image_embedding'] = self.image_embedding[idx] | |||||
item['text_embedding'] = self.text_embedding[idx] | |||||
item['label'] = self.labels[idx] | |||||
return item | |||||
def __len__(self): | |||||
return len(self.text) |
import torch | |||||
from transformers import BertTokenizer, BertModel, BertConfig | |||||
from data.config import Config | |||||
from data.weibo.data_loader import WeiboDatasetLoader | |||||
class WeiboConfig(Config): | |||||
name = 'weibo' | |||||
DatasetLoader = WeiboDatasetLoader | |||||
data_path = 'weibo/' | |||||
output_path = '' | |||||
rumor_image_path = data_path + 'rumor_images/' | |||||
nonrumor_image_path = data_path + 'nonrumor_images/' | |||||
train_text_path = data_path + 'weibo_train.csv' | |||||
validation_text_path = data_path + 'weibo_validation.csv' | |||||
test_text_path = data_path + 'weibo_test.csv' | |||||
batch_size = 128 | |||||
epochs = 100 | |||||
num_workers = 2 | |||||
head_lr = 1e-03 | |||||
image_encoder_lr = 1e-02 | |||||
text_encoder_lr = 1e-05 | |||||
weight_decay = 0.001 | |||||
classification_lr = 1e-02 | |||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||||
image_model_name = 'vit-base-patch16-224' # 'resnet101' | |||||
image_embedding = 768 # 2048 | |||||
num_img_region = 64 # TODO | |||||
text_encoder_model = "bert-base-chinese" | |||||
text_tokenizer = "bert-base-chinese" | |||||
text_embedding = 768 | |||||
max_length = 200 | |||||
pretrained = True | |||||
trainable = False | |||||
temperature = 1.0 | |||||
labels = ['real', 'fake'] | |||||
wanted_accuracy = 0.80 | |||||
def optuna(self, trial): | |||||
self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||||
self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||||
self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||||
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||||
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||||
# self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||||
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||||
self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||||
# self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||||
# self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||||
import pickle | |||||
import cv2 | |||||
from PIL import Image | |||||
from numpy import asarray | |||||
from data.data_loader import DatasetLoader | |||||
class WeiboDatasetLoader(DatasetLoader): | |||||
def __getitem__(self, idx): | |||||
item = self.set_text(idx) | |||||
if self.labels[idx] == 1: | |||||
try: | |||||
image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}") | |||||
except: | |||||
image = Image.open(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||||
image = asarray(image) | |||||
else: | |||||
try: | |||||
image = cv2.imread(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}") | |||||
except: | |||||
image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||||
image = asarray(image) | |||||
self.set_image(image, item) | |||||
return item | |||||
def __len__(self): | |||||
return len(self.text) | |||||
class WeiboEmbeddingDatasetLoader(DatasetLoader): | |||||
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||||
super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||||
self.config = config | |||||
self.image_filenames = image_filenames | |||||
self.text = list(texts) | |||||
self.encoded_text = tokenizer | |||||
self.labels = labels | |||||
self.transforms = transforms | |||||
self.mode = mode | |||||
if mode == 'train': | |||||
with open(config.train_image_embedding_path, 'rb') as handle: | |||||
self.image_embedding = pickle.load(handle) | |||||
with open(config.train_text_embedding_path, 'rb') as handle: | |||||
self.text_embedding = pickle.load(handle) | |||||
else: | |||||
with open(config.validation_image_embedding_path, 'rb') as handle: | |||||
self.image_embedding = pickle.load(handle) | |||||
with open(config.validation_text_embedding_path, 'rb') as handle: | |||||
self.text_embedding = pickle.load(handle) | |||||
def __getitem__(self, idx): | |||||
item = dict() | |||||
item['id'] = idx | |||||
item['image_embedding'] = self.image_embedding[idx] | |||||
item['text_embedding'] = self.text_embedding[idx] | |||||
item['label'] = self.labels[idx] | |||||
return item | |||||
def __len__(self): | |||||
return len(self.text) |
import albumentations as A | |||||
import numpy as np | |||||
import pandas as pd | |||||
from sklearn.utils import compute_class_weight | |||||
from torch.utils.data import DataLoader | |||||
from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||||
def get_transforms(config, mode="train"): | |||||
if mode == "train": | |||||
return A.Compose( | |||||
[ | |||||
A.Resize(config.size, config.size, always_apply=True), | |||||
A.Normalize(max_pixel_value=255.0, always_apply=True), | |||||
] | |||||
) | |||||
else: | |||||
return A.Compose( | |||||
[ | |||||
A.Resize(config.size, config.size, always_apply=True), | |||||
A.Normalize(max_pixel_value=255.0, always_apply=True), | |||||
] | |||||
) | |||||
def make_dfs(config): | |||||
train_dataframe = pd.read_csv(config.train_text_path) | |||||
train_dataframe.dropna(subset=['text'], inplace=True) | |||||
train_dataframe = train_dataframe.sample(frac=1).reset_index(drop=True) | |||||
train_dataframe.label = train_dataframe.label.apply(lambda x: config.classes.index(x)) | |||||
config.class_weights = get_class_weights(train_dataframe.label.values) | |||||
if config.test_text_path is None: | |||||
offset = int(train_dataframe.shape[0] * 0.80) | |||||
test_dataframe = train_dataframe[offset:] | |||||
train_dataframe = train_dataframe[:offset] | |||||
else: | |||||
test_dataframe = pd.read_csv(config.test_text_path) | |||||
test_dataframe.dropna(subset=['text'], inplace=True) | |||||
test_dataframe = test_dataframe.sample(frac=1).reset_index(drop=True) | |||||
test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x)) | |||||
if config.validation_text_path is None: | |||||
offset = int(train_dataframe.shape[0] * 0.90) | |||||
validation_dataframe = train_dataframe[offset:] | |||||
train_dataframe = train_dataframe[:offset] | |||||
else: | |||||
validation_dataframe = pd.read_csv(config.validation_text_path) | |||||
validation_dataframe.dropna(subset=['text'], inplace=True) | |||||
validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True) | |||||
validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x)) | |||||
unlabeled_dataframe = None | |||||
if config.has_unlabeled_data: | |||||
unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path) | |||||
unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True) | |||||
return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe | |||||
def get_tokenizer(config): | |||||
if 'bigbird' in config.text_encoder_model: | |||||
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||||
elif 'xlnet' in config.text_encoder_model: | |||||
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||||
else: | |||||
tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer) | |||||
return tokenizer | |||||
def build_loaders(config, dataframe, mode): | |||||
if 'resnet' in config.image_model_name: | |||||
transforms = get_transforms(config, mode=mode) | |||||
else: | |||||
transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name) | |||||
dataset = config.DatasetLoader( | |||||
config, | |||||
dataframe["image"].values, | |||||
dataframe["text"].values, | |||||
dataframe["label"].values, | |||||
tokenizer=get_tokenizer(config), | |||||
transforms=transforms, | |||||
mode=mode, | |||||
) | |||||
if mode != 'train': | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
batch_size=config.batch_size, | |||||
num_workers=config.num_workers, | |||||
pin_memory=False, | |||||
shuffle=False, | |||||
) | |||||
else: | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
batch_size=config.batch_size, | |||||
num_workers=config.num_workers, | |||||
pin_memory=True, | |||||
shuffle=True, | |||||
) | |||||
return dataloader | |||||
def get_class_weights(y):
    # keyword arguments avoid the positional-argument deprecation in newer scikit-learn
    class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)
    return class_weights
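For reference, `make_dfs` expects the CSV files named in the config to provide at least `image`, `text`, and `label` columns, where `label` holds a class name from `config.classes`. A minimal sketch of a compatible file (toy rows; the image file names are hypothetical):

```python
import pandas as pd

# Toy rows illustrating the columns consumed by make_dfs/build_loaders:
# 'image' is a file name inside the split's image directory, 'text' is the
# post text, and 'label' is one of config.classes ('real' or 'fake').
rows = [
    {'image': 'post_001.jpg', 'text': 'Breaking news ...', 'label': 'fake'},
    {'image': 'post_002.jpg', 'text': 'Official statement ...', 'label': 'real'},
]
pd.DataFrame(rows).to_csv('twitter_train_translated.csv', index=False)
```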
from itertools import cycle | |||||
import matplotlib.pyplot as plt | |||||
import numpy as np | |||||
import pandas as pd | |||||
import seaborn as sns | |||||
from numpy import interp | |||||
from sklearn.decomposition import PCA | |||||
from sklearn.manifold import TSNE | |||||
from sklearn.metrics import accuracy_score, f1_score, precision_recall_curve, average_precision_score | |||||
from sklearn.metrics import classification_report, roc_curve, auc | |||||
from sklearn.preprocessing import OneHotEncoder | |||||
def metrics(truth, pred, prob, file_path): | |||||
truth = [i.cpu().numpy() for i in truth] | |||||
pred = [i.cpu().numpy() for i in pred] | |||||
prob = [i.cpu().numpy() for i in prob] | |||||
pred = np.concatenate(pred, axis=0) | |||||
truth = np.concatenate(truth, axis=0) | |||||
prob = np.concatenate(prob, axis=0) | |||||
prob = prob[:, 1] | |||||
f_score_micro = f1_score(truth, pred, average='micro', zero_division=0) | |||||
f_score_macro = f1_score(truth, pred, average='macro', zero_division=0) | |||||
f_score_weighted = f1_score(truth, pred, average='weighted', zero_division=0) | |||||
accuracy = accuracy_score(truth, pred)
s = ''
print('accuracy', accuracy)
s += 'accuracy ' + str(accuracy) + '\n'
print('f_score_micro', f_score_micro)
s += 'f_score_micro ' + str(f_score_micro) + '\n'
print('f_score_macro', f_score_macro)
s += 'f_score_macro ' + str(f_score_macro) + '\n'
print('f_score_weighted', f_score_weighted)
s += 'f_score_weighted ' + str(f_score_weighted) + '\n'
fpr, tpr, thresholds = roc_curve(truth, prob)
AUC = auc(fpr, tpr)
print('AUC', AUC)
s += 'AUC ' + str(AUC) + '\n'
df = pd.DataFrame(dict(fpr=fpr, tpr=tpr)) | |||||
df.to_csv(file_path) | |||||
return s | |||||
def report_per_class(truth, pred): | |||||
truth = [i.cpu().numpy() for i in truth] | |||||
pred = [i.cpu().numpy() for i in pred] | |||||
pred = np.concatenate(pred, axis=0) | |||||
truth = np.concatenate(truth, axis=0) | |||||
report = classification_report(truth, pred, zero_division=0, output_dict=True) | |||||
s = '' | |||||
class_labels = [k for k in report.keys() if k not in ['micro avg', 'macro avg', 'weighted avg', 'samples avg']] | |||||
for class_label in class_labels: | |||||
print('class_label', class_label) | |||||
s += 'class_label' + str(class_label) + '\n' | |||||
s += str(report[class_label]) | |||||
print(report[class_label]) | |||||
return s | |||||
def multiclass_acc(truth, pred): | |||||
truth = [i.cpu().numpy() for i in truth] | |||||
pred = [i.cpu().numpy() for i in pred] | |||||
pred = np.concatenate(pred, axis=0) | |||||
truth = np.concatenate(truth, axis=0) | |||||
return accuracy_score(truth, pred) | |||||
def roc_auc_plot(truth, score, num_class=2, fname='roc.png'): | |||||
truth = [i.cpu().numpy() for i in truth] | |||||
score = [i.cpu().numpy() for i in score] | |||||
truth = np.concatenate(truth, axis=0) | |||||
score = np.concatenate(score, axis=0) | |||||
enc = OneHotEncoder(handle_unknown='ignore') | |||||
enc.fit(truth.reshape(-1, 1)) | |||||
label_onehot = enc.transform(truth.reshape(-1, 1)).toarray() | |||||
fpr_dict = dict() | |||||
tpr_dict = dict() | |||||
roc_auc_dict = dict() | |||||
for i in range(num_class): | |||||
fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score[:, i]) | |||||
roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i]) | |||||
# micro | |||||
fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score.ravel()) | |||||
roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"]) | |||||
# macro | |||||
all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)])) | |||||
mean_tpr = np.zeros_like(all_fpr) | |||||
for i in range(num_class): | |||||
mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i]) | |||||
mean_tpr /= num_class | |||||
fpr_dict["macro"] = all_fpr | |||||
tpr_dict["macro"] = mean_tpr | |||||
roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"]) | |||||
plt.figure() | |||||
lw = 2 | |||||
plt.plot(fpr_dict["micro"], tpr_dict["micro"], | |||||
label='micro-average ROC curve (area = {0:0.2f})' | |||||
''.format(roc_auc_dict["micro"]), | |||||
color='deeppink', linestyle=':', linewidth=4) | |||||
plt.plot(fpr_dict["macro"], tpr_dict["macro"], | |||||
label='macro-average ROC curve (area = {0:0.2f})' | |||||
''.format(roc_auc_dict["macro"]), | |||||
color='navy', linestyle=':', linewidth=4) | |||||
colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) | |||||
for i, color in zip(range(num_class), colors): | |||||
plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw, | |||||
label='ROC curve of class {0} (area = {1:0.2f})' | |||||
''.format(i, roc_auc_dict[i])) | |||||
plt.plot([0, 1], [0, 1], 'k--', lw=lw) | |||||
plt.xlim([0.0, 1.0]) | |||||
plt.ylim([0.0, 1.05]) | |||||
plt.xlabel('False Positive Rate') | |||||
plt.ylabel('True Positive Rate') | |||||
plt.legend(loc="lower right") | |||||
plt.savefig(fname) | |||||
# plt.show() | |||||
def precision_recall_plot(truth, score, num_class=2, fname='pr.png'): | |||||
truth = [i.cpu().numpy() for i in truth] | |||||
score = [i.cpu().numpy() for i in score] | |||||
truth = np.concatenate(truth, axis=0) | |||||
score = np.concatenate(score, axis=0) | |||||
enc = OneHotEncoder(handle_unknown='ignore') | |||||
enc.fit(truth.reshape(-1, 1)) | |||||
label_onehot = enc.transform(truth.reshape(-1, 1)).toarray() | |||||
# Call the Sklearn library, calculate the precision and recall corresponding to each category | |||||
precision_dict = dict() | |||||
recall_dict = dict() | |||||
average_precision_dict = dict() | |||||
for i in range(num_class): | |||||
precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score[:, i]) | |||||
average_precision_dict[i] = average_precision_score(label_onehot[:, i], score[:, i]) | |||||
print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i]) | |||||
# micro | |||||
precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(), | |||||
score.ravel()) | |||||
average_precision_dict["micro"] = average_precision_score(label_onehot, score, average="micro") | |||||
# macro | |||||
all_fpr = np.unique(np.concatenate([precision_dict[i] for i in range(num_class)])) | |||||
mean_tpr = np.zeros_like(all_fpr) | |||||
for i in range(num_class): | |||||
mean_tpr += interp(all_fpr, precision_dict[i], recall_dict[i]) | |||||
mean_tpr /= num_class | |||||
precision_dict["macro"] = all_fpr | |||||
recall_dict["macro"] = mean_tpr | |||||
average_precision_dict["macro"] = auc(precision_dict["macro"], recall_dict["macro"]) | |||||
plt.figure() | |||||
plt.subplots(figsize=(16, 10)) | |||||
lw = 2 | |||||
plt.plot(precision_dict["micro"], recall_dict["micro"], | |||||
label='micro-average Precision-Recall curve (area = {0:0.2f})' | |||||
''.format(average_precision_dict["micro"]), | |||||
color='deeppink', linestyle=':', linewidth=4) | |||||
plt.plot(precision_dict["macro"], recall_dict["macro"], | |||||
label='macro-average Precision-Recall curve (area = {0:0.2f})' | |||||
''.format(average_precision_dict["macro"]), | |||||
color='navy', linestyle=':', linewidth=4) | |||||
colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) | |||||
for i, color in zip(range(num_class), colors): | |||||
plt.plot(precision_dict[i], recall_dict[i], color=color, lw=lw, | |||||
label='Precision-Recall curve of class {0} (area = {1:0.2f})' | |||||
''.format(i, average_precision_dict[i])) | |||||
plt.plot([0, 1], [0, 1], 'k--', lw=lw) | |||||
plt.xlabel('Recall') | |||||
plt.ylabel('Precision') | |||||
plt.ylim([0.0, 1.05]) | |||||
plt.xlim([0.0, 1.0]) | |||||
plt.legend(loc="lower left") | |||||
plt.savefig(fname=fname) | |||||
# plt.show() | |||||
def saving_in_tensorboard(config, x, y, fname='embedding'): | |||||
x = [i.cpu().numpy() for i in x] | |||||
y = [i.cpu().numpy() for i in y] | |||||
x = np.concatenate(x, axis=0) | |||||
y = np.concatenate(y, axis=0) | |||||
z = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||||
# config.writer.add_embedding(mat=x, label_img=y, metadata=z, tag=fname) | |||||
def plot_tsne(config, x, y, fname='tsne.png'): | |||||
x = [i.cpu().numpy() for i in x] | |||||
y = [i.cpu().numpy() for i in y] | |||||
x = np.concatenate(x, axis=0) | |||||
y = np.concatenate(y, axis=0) | |||||
y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||||
tsne = TSNE(n_components=2, verbose=1, init="pca", perplexity=10, learning_rate=1000) | |||||
tsne_proj = tsne.fit_transform(x) | |||||
fig, ax = plt.subplots(figsize=(16, 10)) | |||||
palette = sns.color_palette("bright", 2) | |||||
sns.scatterplot(tsne_proj[:, 0], tsne_proj[:, 1], hue=y, legend='full', palette=palette) | |||||
ax.legend(fontsize='large', markerscale=2) | |||||
plt.title('tsne of ' + str(fname.split('/')[-1].split('.')[0])) | |||||
plt.savefig(fname=fname) | |||||
# plt.show() | |||||
def plot_pca(config, x, y, fname='pca.png'): | |||||
x = [i.cpu().numpy() for i in x] | |||||
y = [i.cpu().numpy() for i in y] | |||||
x = np.concatenate(x, axis=0) | |||||
y = np.concatenate(y, axis=0) | |||||
y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||||
pca = PCA(n_components=2) | |||||
pca_proj = pca.fit_transform(x) | |||||
fig, ax = plt.subplots(figsize=(16, 10)) | |||||
palette = sns.color_palette("bright", 2) | |||||
sns.scatterplot(pca_proj[:, 0], pca_proj[:, 1], hue=y, legend='full', palette=palette) | |||||
ax.legend(fontsize='large', markerscale=2) | |||||
plt.title('pca of ' + str(fname.split('/')[-1].split('.')[0])) | |||||
plt.savefig(fname=fname) | |||||
# plt.show() |
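A minimal usage sketch of the functions above with toy tensors (the real pipeline passes lists of per-batch tensors accumulated in `test_main.py`; the output file names here are arbitrary):

```python
import torch
from evaluation import metrics, report_per_class, roc_auc_plot

# one list entry per batch, mirroring how test() accumulates model outputs
truth = [torch.tensor([0, 1, 1]), torch.tensor([0, 0])]
pred = [torch.tensor([0, 1, 0]), torch.tensor([0, 1])]
score = [torch.rand(3, 2).softmax(dim=1), torch.rand(2, 2).softmax(dim=1)]

print(report_per_class(truth, pred))
print(metrics(truth, pred, score, file_path='fpr_tpr.csv'))
roc_auc_plot(truth, score, num_class=2, fname='roc.png')
```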
import timm | |||||
from torch import nn | |||||
from transformers import ViTModel, ViTConfig | |||||
class ImageTransformerEncoder(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
if config.pretrained: | |||||
self.model = ViTModel.from_pretrained(config.image_model_name, output_attentions=False, | |||||
output_hidden_states=True, return_dict=True) | |||||
else: | |||||
self.model = ViTModel(config=ViTConfig()) | |||||
for p in self.model.parameters(): | |||||
p.requires_grad = config.trainable | |||||
self.target_token_idx = 0 | |||||
self.image_encoder_embedding = dict() | |||||
def forward(self, ids, image): | |||||
output = self.model(image) | |||||
last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :] | |||||
# for i, id in enumerate(ids): | |||||
# id = int(id.detach().cpu().numpy()) | |||||
# self.image_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy() | |||||
return last_hidden_state | |||||
class ImageResnetEncoder(nn.Module): | |||||
def __init__( | |||||
self, config | |||||
): | |||||
super().__init__() | |||||
self.model = timm.create_model( | |||||
config.image_model_name, config.pretrained, num_classes=0, global_pool="avg" | |||||
) | |||||
for p in self.model.parameters(): | |||||
p.requires_grad = config.trainable | |||||
def forward(self, ids, image): | |||||
return self.model(image) |
import gc | |||||
import itertools | |||||
import random | |||||
import numpy as np | |||||
import optuna | |||||
import pandas as pd | |||||
import torch | |||||
from evaluation import multiclass_acc | |||||
from model import FakeNewsModel, calculate_loss | |||||
from utils import AvgMeter, print_lr, EarlyStopping, CheckpointSaving | |||||
def batch_constructor(config, batch): | |||||
b = {} | |||||
for k, v in batch.items(): | |||||
if k != 'text': | |||||
b[k] = v.to(config.device) | |||||
return b | |||||
def train_epoch(config, model, train_loader, optimizer, scalar): | |||||
loss_meter = AvgMeter('train') | |||||
# tqdm_object = tqdm(train_loader, total=len(train_loader)) | |||||
targets = [] | |||||
predictions = [] | |||||
for index, batch in enumerate(train_loader): | |||||
batch = batch_constructor(config, batch) | |||||
optimizer.zero_grad(set_to_none=True) | |||||
with torch.cuda.amp.autocast(): | |||||
output, score = model(batch) | |||||
loss = calculate_loss(model, score, batch['label']) | |||||
scalar.scale(loss).backward() | |||||
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) | |||||
if (index + 1) % 2: | |||||
scalar.step(optimizer) | |||||
# loss.backward() | |||||
# optimizer.step() | |||||
scalar.update() | |||||
count = batch["id"].size(0) | |||||
loss_meter.update(loss.detach(), count) | |||||
# tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer)) | |||||
prediction = output.detach() | |||||
predictions.append(prediction) | |||||
target = batch['label'].detach() | |||||
targets.append(target) | |||||
return loss_meter, targets, predictions | |||||
def validation_epoch(config, model, validation_loader): | |||||
loss_meter = AvgMeter('validation') | |||||
targets = [] | |||||
predictions = [] | |||||
# tqdm_object = tqdm(validation_loader, total=len(validation_loader)) | |||||
for batch in validation_loader: | |||||
batch = batch_constructor(config, batch) | |||||
with torch.no_grad(): | |||||
output, score = model(batch) | |||||
loss = calculate_loss(model, score, batch['label']) | |||||
count = batch["id"].size(0) | |||||
loss_meter.update(loss.detach(), count) | |||||
# tqdm_object.set_postfix(validation_loss=loss_meter.avg) | |||||
prediction = output.detach() | |||||
predictions.append(prediction) | |||||
target = batch['label'].detach() | |||||
targets.append(target) | |||||
return loss_meter, targets, predictions | |||||
def supervised_train(config, train_loader, validation_loader, trial=None): | |||||
torch.cuda.empty_cache() | |||||
checkpoint_path2 = checkpoint_path = str(config.output_path) + '/checkpoint.pt' | |||||
if trial: | |||||
checkpoint_path2 = str(config.output_path) + '/checkpoint_' + str(trial.number) + '.pt' | |||||
torch.manual_seed(27) | |||||
random.seed(27) | |||||
np.random.seed(27) | |||||
torch.autograd.set_detect_anomaly(False) | |||||
torch.autograd.profiler.profile(False) | |||||
torch.autograd.profiler.emit_nvtx(False) | |||||
scalar = torch.cuda.amp.GradScaler() | |||||
model = FakeNewsModel(config).to(config.device) | |||||
params = [ | |||||
{"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr, "name": 'image_encoder'}, | |||||
{"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr, "name": 'text_encoder'}, | |||||
{"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()), | |||||
"lr": config.head_lr, "weight_decay": config.head_weight_decay, 'name': 'projection'}, | |||||
{"params": model.classifier.parameters(), "lr": config.classification_lr, | |||||
"weight_decay": config.classification_weight_decay, | |||||
'name': 'classifier'} | |||||
] | |||||
optimizer = torch.optim.AdamW(params, amsgrad=True) | |||||
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.factor, | |||||
patience=config.patience // 5, verbose=True) | |||||
early_stopping = EarlyStopping(patience=config.patience, delta=config.delta, path=checkpoint_path, verbose=True) | |||||
checkpoint_saving = CheckpointSaving(path=checkpoint_path, verbose=True) | |||||
train_losses, train_accuracies = [], [] | |||||
validation_losses, validation_accuracies = [], [] | |||||
validation_accuracy, validation_loss = 0, 1 | |||||
for epoch in range(config.epochs): | |||||
print(f"Epoch: {epoch + 1}") | |||||
gc.collect() | |||||
model.train() | |||||
train_loss, train_truth, train_pred = train_epoch(config, model, train_loader, optimizer, scalar) | |||||
model.eval() | |||||
with torch.no_grad(): | |||||
validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader) | |||||
train_accuracy = multiclass_acc(train_truth, train_pred) | |||||
validation_accuracy = multiclass_acc(validation_truth, validation_pred) | |||||
print_lr(optimizer) | |||||
print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy) | |||||
print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy) | |||||
if lr_scheduler: | |||||
lr_scheduler.step(validation_loss.avg) | |||||
if early_stopping: | |||||
early_stopping(validation_loss.avg, model) | |||||
if early_stopping.early_stop: | |||||
print("Early stopping") | |||||
break | |||||
if checkpoint_saving: | |||||
checkpoint_saving(validation_accuracy, model) | |||||
train_accuracies.append(train_accuracy) | |||||
train_losses.append(train_loss) | |||||
validation_accuracies.append(validation_accuracy) | |||||
validation_losses.append(validation_loss) | |||||
if trial: | |||||
trial.report(validation_accuracy, epoch) | |||||
if trial.should_prune(): | |||||
print('trial pruned') | |||||
raise optuna.exceptions.TrialPruned() | |||||
print() | |||||
if checkpoint_saving: | |||||
model = FakeNewsModel(config).to(config.device) | |||||
model.load_state_dict(torch.load(checkpoint_path)) | |||||
model.eval() | |||||
with torch.no_grad(): | |||||
validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader) | |||||
validation_accuracy = multiclass_acc(validation_pred, validation_truth) | |||||
if trial and validation_accuracy >= config.wanted_accuracy: | |||||
loss_accuracy = pd.DataFrame( | |||||
{'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses, | |||||
'validation_accuracy': validation_accuracies}) | |||||
torch.save({'model_state_dict': model.state_dict(), | |||||
'parameters': str(config), | |||||
'optimizer_state_dict': optimizer.state_dict(), | |||||
'loss_accuracy': loss_accuracy}, checkpoint_path2) | |||||
if not checkpoint_saving: | |||||
loss_accuracy = pd.DataFrame( | |||||
{'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses, | |||||
'validation_accuracy': validation_accuracies}) | |||||
torch.save(model.state_dict(), checkpoint_path) | |||||
if trial and validation_accuracy >= config.wanted_accuracy: | |||||
torch.save({'model_state_dict': model.state_dict(), | |||||
'parameters': str(config), | |||||
'optimizer_state_dict': optimizer.state_dict(), | |||||
'loss_accuracy': loss_accuracy}, checkpoint_path2) | |||||
try: | |||||
del train_loss | |||||
del train_truth | |||||
del train_pred | |||||
del validation_loss | |||||
del validation_truth | |||||
del validation_pred | |||||
del train_losses | |||||
del train_accuracies | |||||
del validation_losses | |||||
del validation_accuracies | |||||
del loss_accuracy | |||||
del scalar | |||||
del model | |||||
del params | |||||
except: | |||||
print('Error in deleting caches') | |||||
pass | |||||
return validation_accuracy | |||||
import argparse
from optuna_main import optuna_main
from test_main import test_main
from torch_main import torch_main
from data.twitter.config import TwitterConfig
from data.weibo.config import WeiboConfig
from torch.utils.tensorboard import SummaryWriter

if __name__ == '__main__':
    import os

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--use_optuna', type=int, required=False)
    parser.add_argument('--just_test', type=int, required=False)
    parser.add_argument('--batch', type=int, required=False)
    parser.add_argument('--epoch', type=int, required=False)
    parser.add_argument('--extra', type=str, required=False)
    args = parser.parse_args()

    if args.data == 'twitter':
        config = TwitterConfig()
    elif args.data == 'weibo':
        config = WeiboConfig()
    else:
        raise Exception('Enter a valid dataset name', args.data)

    if args.batch:
        config.batch_size = args.batch
    if args.epoch:
        config.epochs = args.epoch

    if args.use_optuna:
        config.output_path += 'logs/' + args.data + '_optuna' + '_' + str(args.extra)
    else:
        config.output_path += 'logs/' + args.data + '_' + str(args.extra)
    use_optuna = True
    if not args.extra or 'temp' in args.extra:
        config.output_path = str(args.extra)
        use_optuna = False

    try:
        os.mkdir(config.output_path)
    except OSError:
        print("Creation of the directory failed")
    else:
        print("Successfully created the directory")

    config.writer = SummaryWriter(config.output_path)
    if args.use_optuna and use_optuna:
        optuna_main(config, args.use_optuna)
    elif args.just_test:
        test_main(config, args.just_test)
    else:
        torch_main(config)
    config.writer.close()
import torch | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
from image import ImageTransformerEncoder, ImageResnetEncoder | |||||
from text import TextEncoder | |||||
class ProjectionHead(nn.Module): | |||||
def __init__( | |||||
self, | |||||
config, | |||||
embedding_dim, | |||||
): | |||||
super().__init__() | |||||
self.config = config | |||||
self.projection = nn.Linear(embedding_dim, config.projection_size) | |||||
self.gelu = nn.GELU() | |||||
self.fc = nn.Linear(config.projection_size, config.projection_size) | |||||
self.dropout = nn.Dropout(config.dropout) | |||||
self.layer_norm = nn.LayerNorm(config.projection_size) | |||||
def forward(self, x): | |||||
projected = self.projection(x) | |||||
x = self.gelu(projected) | |||||
x = self.fc(x) | |||||
x = self.dropout(x) | |||||
x = x + projected | |||||
x = self.layer_norm(x) | |||||
return x | |||||
class Classifier(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.layer_norm_1 = nn.LayerNorm(2 * config.projection_size) | |||||
self.linear_layer = nn.Linear(2 * config.projection_size, config.hidden_size) | |||||
self.gelu = nn.GELU() | |||||
self.drop_out = nn.Dropout(config.dropout) | |||||
self.layer_norm_2 = nn.LayerNorm(config.hidden_size) | |||||
self.classifier_layer = nn.Linear(config.hidden_size, config.class_num) | |||||
self.softmax = nn.Softmax(dim=1) | |||||
def forward(self, x): | |||||
x = self.layer_norm_1(x) | |||||
x = self.linear_layer(x) | |||||
x = self.gelu(x) | |||||
x = self.drop_out(x) | |||||
self.embeddings = x = self.layer_norm_2(x) | |||||
x = self.classifier_layer(x) | |||||
x = self.softmax(x) | |||||
return x | |||||
class FakeNewsModel(nn.Module): | |||||
def __init__( | |||||
self, config | |||||
): | |||||
super().__init__() | |||||
self.config = config | |||||
if 'resnet' in self.config.image_model_name: | |||||
self.image_encoder = ImageResnetEncoder(config) | |||||
else: | |||||
self.image_encoder = ImageTransformerEncoder(config) | |||||
self.text_encoder = TextEncoder(config) | |||||
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||||
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||||
self.classifier = Classifier(config) | |||||
self.temperature = config.temperature | |||||
class_weights = torch.FloatTensor(config.class_weights) | |||||
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||||
self.text_embeddings = None | |||||
self.image_embeddings = None | |||||
self.multimodal_embeddings = None | |||||
def forward(self, batch): | |||||
image_features = self.image_encoder(ids=batch['id'], image=batch["image"]) | |||||
text_features = self.text_encoder(ids=batch['id'], | |||||
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) | |||||
self.image_embeddings = self.image_projection(image_features) | |||||
self.text_embeddings = self.text_projection(text_features) | |||||
self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||||
score = self.classifier(self.multimodal_embeddings) | |||||
probs, output = torch.max(score.data, dim=1) | |||||
return output, score | |||||
class FakeNewsEmbeddingModel(nn.Module): | |||||
def __init__( | |||||
self, config | |||||
): | |||||
super().__init__() | |||||
self.config = config | |||||
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||||
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||||
self.classifier = Classifier(config) | |||||
self.temperature = config.temperature | |||||
class_weights = torch.FloatTensor(config.class_weights) | |||||
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||||
def forward(self, batch): | |||||
self.image_embeddings = self.image_projection(batch['image_embedding']) | |||||
self.text_embeddings = self.text_projection(batch['text_embedding']) | |||||
self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||||
score = self.classifier(self.multimodal_embeddings) | |||||
probs, output = torch.max(score.data, dim=1) | |||||
return output, score | |||||
def calculate_loss(model, score, label): | |||||
similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings) | |||||
# fake = (label == 1).nonzero() | |||||
# real = (label == 0).nonzero() | |||||
# s_loss = 0 * similarity[fake].mean() + similarity[real].mean() | |||||
c_loss = model.classifier_loss_function(score, label) | |||||
loss = model.config.loss_weight * c_loss + similarity.mean() | |||||
return loss | |||||
def calculate_similarity_loss(config, image_embeddings, text_embeddings): | |||||
# Calculating the Loss | |||||
logits = (text_embeddings @ image_embeddings.T) / config.temperature | |||||
images_similarity = image_embeddings @ image_embeddings.T | |||||
texts_similarity = text_embeddings @ text_embeddings.T | |||||
targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1) | |||||
texts_loss = cross_entropy(logits, targets, reduction='none') | |||||
images_loss = cross_entropy(logits.T, targets.T, reduction='none') | |||||
loss = (images_loss + texts_loss) / 2.0 | |||||
return loss | |||||
def cross_entropy(preds, targets, reduction='none'): | |||||
log_softmax = nn.LogSoftmax(dim=-1) | |||||
loss = (-targets * log_softmax(preds)).sum(1) | |||||
if reduction == "none": | |||||
return loss | |||||
elif reduction == "mean": | |||||
return loss.mean() |
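A small sketch of the similarity-loss helper above in isolation (random embeddings and a stub config carrying only the attribute the function reads); this only illustrates input and output shapes, not the training setup:

```python
import torch
from model import calculate_similarity_loss

class StubConfig:
    temperature = 1.0  # the only attribute read by calculate_similarity_loss

image_embeddings = torch.randn(4, 256)
text_embeddings = torch.randn(4, 256)

# returns one loss value per sample; calculate_loss() adds its mean to the
# weighted classification loss
loss = calculate_similarity_loss(StubConfig(), image_embeddings, text_embeddings)
print(loss.shape, loss.mean().item())
```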
import joblib | |||||
import optuna | |||||
from optuna.pruners import MedianPruner | |||||
from optuna.trial import TrialState | |||||
from data_loaders import make_dfs, build_loaders | |||||
from learner import supervised_train | |||||
from test_main import test | |||||
def objective(trial, config, train_loader, validation_loader): | |||||
config.optuna(trial=trial) | |||||
print('Trial', trial.number, 'parameters', trial.params) | |||||
accuracy = supervised_train(config, train_loader, validation_loader, trial=trial) | |||||
return accuracy | |||||
def optuna_main(config, n_trials=100): | |||||
train_df, test_df, validation_df, _ = make_dfs(config) | |||||
train_loader = build_loaders(config, train_df, mode="train") | |||||
validation_loader = build_loaders(config, validation_df, mode="validation") | |||||
test_loader = build_loaders(config, test_df, mode="test") | |||||
study = optuna.create_study(study_name=config.output_path.split('/')[-1], | |||||
sampler=optuna.samplers.TPESampler(), | |||||
storage=f'sqlite:///{config.output_path + "/optuna.db"}', | |||||
load_if_exists=True, | |||||
direction="maximize", | |||||
pruner=MedianPruner(n_startup_trials=10, n_warmup_steps=10) | |||||
) | |||||
study.optimize(lambda trial: objective(trial, config, train_loader, validation_loader), n_trials=n_trials) | |||||
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED]) | |||||
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE]) | |||||
joblib.dump(study, str(config.output_path) + '/study_optuna_model' + '.pkl') | |||||
print("Study statistics: ") | |||||
print(" Number of finished trials: ", len(study.trials)) | |||||
print(" Number of pruned trials: ", len(pruned_trials)) | |||||
print(" Number of complete trials: ", len(complete_trials)) | |||||
s = '' | |||||
print("Best trial:") | |||||
trial = study.best_trial | |||||
print(' Number: ', trial.number) | |||||
print(" Value: ", trial.value) | |||||
s += 'value: ' + str(trial.value) + '\n' | |||||
print(" Params: ") | |||||
s += 'params: \n' | |||||
for key, value in trial.params.items(): | |||||
print(" {}: {}".format(key, value)) | |||||
s += " {}: {}\n".format(key, value) | |||||
with open(config.output_path+'/optuna_results.txt', 'w') as f: | |||||
f.write(s) | |||||
test(config, test_loader, trial_number=trial.number) | |||||
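After a study finishes, the `joblib` dump written above can be reloaded to inspect the best trial offline; a minimal sketch (the path is illustrative and depends on `config.output_path`):

```python
import joblib

# path is illustrative; optuna_main saves the study under config.output_path
study = joblib.load('logs/twitter_optuna_run1/study_optuna_model.pkl')
best = study.best_trial
print(best.number, best.value)
print(best.params)
```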
opencv-python==4.5.1.48
optuna==2.8.0
pandas==1.1.5
Pillow==8.2.0
pytorch-lightning==1.2.10
pytorch-model-summary==0.1.2
pytorchtools==0.0.2
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.5.4
seaborn==0.11.1
torch==1.8.1
torchmetrics==0.2.0
torchsummary==1.5.1
torchvision==0.9.1
tqdm==4.60.0
transformers==4.5.1
import random | |||||
import numpy as np | |||||
import torch | |||||
from tqdm import tqdm | |||||
from data_loaders import make_dfs, build_loaders | |||||
from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca | |||||
from learner import batch_constructor | |||||
from model import FakeNewsModel | |||||
def test(config, test_loader, trial_number=None): | |||||
if trial_number: | |||||
try: | |||||
checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt') | |||||
except: | |||||
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt') | |||||
else: | |||||
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device)) | |||||
try: | |||||
parameters = checkpoint['parameters'] | |||||
config.assign_hyperparameters(parameters) | |||||
except: | |||||
pass | |||||
model = FakeNewsModel(config).to(config.device) | |||||
try: | |||||
model.load_state_dict(checkpoint['model_state_dict']) | |||||
except: | |||||
model.load_state_dict(checkpoint) | |||||
model.eval() | |||||
torch.manual_seed(27) | |||||
random.seed(27) | |||||
np.random.seed(27) | |||||
image_features = [] | |||||
text_features = [] | |||||
multimodal_features = [] | |||||
concat_features = [] | |||||
targets = [] | |||||
predictions = [] | |||||
scores = [] | |||||
tqdm_object = tqdm(test_loader, total=len(test_loader)) | |||||
for i, batch in enumerate(tqdm_object): | |||||
batch = batch_constructor(config, batch) | |||||
with torch.no_grad(): | |||||
output, score = model(batch) | |||||
prediction = output.detach() | |||||
predictions.append(prediction) | |||||
score = score.detach() | |||||
scores.append(score) | |||||
target = batch['label'].detach() | |||||
targets.append(target) | |||||
image_feature = model.image_embeddings.detach() | |||||
image_features.append(image_feature) | |||||
text_feature = model.text_embeddings.detach() | |||||
text_features.append(text_feature) | |||||
multimodal_feature = model.multimodal_embeddings.detach() | |||||
multimodal_features.append(multimodal_feature) | |||||
concat_feature = model.classifier.embeddings.detach() | |||||
concat_features.append(concat_feature) | |||||
# config.writer.add_graph(model, input_to_model=batch, verbose=True) | |||||
s = '' | |||||
s += report_per_class(targets, predictions) + '\n' | |||||
s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n' | |||||
with open(config.output_path + '/results.txt', 'w') as f: | |||||
f.write(s) | |||||
roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png") | |||||
precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png") | |||||
# saving_in_tensorboard(config, image_features, targets, 'image_features') | |||||
plot_tsne(config, image_features, targets, fname=str(config.output_path) + '/image_features_tsne.png') | |||||
plot_pca(config, image_features, targets, fname=str(config.output_path) + '/image_features_pca.png') | |||||
# saving_in_tensorboard(config, text_features, targets, 'text_features') | |||||
plot_tsne(config, text_features, targets, fname=str(config.output_path) + '/text_features_tsne.png') | |||||
plot_pca(config, text_features, targets, fname=str(config.output_path) + '/text_features_pca.png') | |||||
# | |||||
# saving_in_tensorboard(config, multimodal_features, targets, 'multimodal_features') | |||||
plot_tsne(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_tsne.png') | |||||
plot_pca(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_pca.png') | |||||
# saving_in_tensorboard(config, concat_features, targets, 'concat_features') | |||||
plot_tsne(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_tsne.png') | |||||
plot_pca(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_pca.png') | |||||
config_parameters = str(config) | |||||
with open(config.output_path + '/parameters.txt', 'w') as f: | |||||
f.write(config_parameters) | |||||
print(config) | |||||
def test_main(config, trial_number=None): | |||||
train_df, test_df, validation_df, _ = make_dfs(config, ) | |||||
test_loader = build_loaders(config, test_df, mode="test") | |||||
test(config, test_loader, trial_number) |
from torch import nn | |||||
from transformers import BertModel, BertConfig, BertTokenizer, \ | |||||
BigBirdModel, BigBirdConfig, BigBirdTokenizer, \ | |||||
XLNetModel, XLNetConfig, XLNetTokenizer | |||||
class TextEncoder(nn.Module): | |||||
def __init__(self, config): | |||||
super().__init__() | |||||
self.config = config | |||||
self.model = get_text_model(config) | |||||
for p in self.model.parameters(): | |||||
p.requires_grad = self.config.trainable | |||||
self.target_token_idx = 0 | |||||
self.text_encoder_embedding = dict() | |||||
def forward(self, ids, input_ids, attention_mask): | |||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask) | |||||
last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :] | |||||
# for i, id in enumerate(ids): | |||||
# id = int(id.detach().cpu().numpy()) | |||||
# self.text_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy() | |||||
return last_hidden_state | |||||
def get_text_model(config): | |||||
if 'bigbird' in config.text_encoder_model: | |||||
if config.pretrained: | |||||
model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2) | |||||
else: | |||||
model = BigBirdModel(config=BigBirdConfig()) | |||||
elif 'xlnet' in config.text_encoder_model: | |||||
if config.pretrained: | |||||
model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||||
output_hidden_states=True, return_dict=True) | |||||
else: | |||||
model = XLNetModel(config=XLNetConfig()) | |||||
else: | |||||
if config.pretrained: | |||||
model = BertModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||||
output_hidden_states=True, return_dict=True) | |||||
else: | |||||
model = BertModel(config=BertConfig()) | |||||
return model | |||||
from data_loaders import build_loaders, make_dfs
from learner import supervised_train
from test_main import test


def torch_main(config):
    train_df, test_df, validation_df, _ = make_dfs(config)
    train_loader = build_loaders(config, train_df, mode="train")
    validation_loader = build_loaders(config, validation_df, mode="validation")
    test_loader = build_loaders(config, test_df, mode="test")
    supervised_train(config, train_loader, validation_loader)
    test(config, test_loader)
import numpy as np | |||||
import torch | |||||
class AvgMeter: | |||||
def __init__(self, name="Metric"): | |||||
self.name = name | |||||
self.reset() | |||||
def reset(self): | |||||
self.avg, self.sum, self.count = [0] * 3 | |||||
def update(self, val, count=1): | |||||
self.count += count | |||||
self.sum += val * count | |||||
self.avg = self.sum / self.count | |||||
def __repr__(self): | |||||
text = f"{self.name}: {self.avg:.4f}" | |||||
return text | |||||
def print_lr(optimizer): | |||||
for param_group in optimizer.param_groups: | |||||
print(param_group['name'], param_group['lr']) | |||||
class CheckpointSaving: | |||||
def __init__(self, path='checkpoint.pt', verbose=True, trace_func=print): | |||||
self.best_score = None | |||||
self.val_acc_max = 0 | |||||
self.path = path | |||||
self.verbose = verbose | |||||
self.trace_func = trace_func | |||||
def __call__(self, val_acc, model): | |||||
if self.best_score is None: | |||||
self.best_score = val_acc | |||||
self.save_checkpoint(val_acc, model) | |||||
elif val_acc > self.best_score: | |||||
self.best_score = val_acc | |||||
self.save_checkpoint(val_acc, model) | |||||
def save_checkpoint(self, val_acc, model): | |||||
if self.verbose: | |||||
self.trace_func( | |||||
f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}). Model saved ...') | |||||
torch.save(model.state_dict(), self.path) | |||||
self.val_acc_max = val_acc | |||||
class EarlyStopping: | |||||
def __init__(self, patience=10, verbose=False, delta=0.000001, path='checkpoint.pt', trace_func=print): | |||||
self.patience = patience | |||||
self.verbose = verbose | |||||
self.counter = 0 | |||||
self.best_score = None | |||||
self.early_stop = False | |||||
self.val_loss_min = np.Inf | |||||
self.delta = delta | |||||
self.path = path | |||||
self.trace_func = trace_func | |||||
def __call__(self, val_loss, model): | |||||
score = -val_loss | |||||
if self.best_score is None: | |||||
self.best_score = score | |||||
if self.verbose: | |||||
self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') | |||||
self.val_loss_min = val_loss | |||||
# self.save_checkpoint(val_loss, model) | |||||
elif score < self.best_score + self.delta: | |||||
self.counter += 1 | |||||
self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') | |||||
if self.counter >= self.patience: | |||||
self.early_stop = True | |||||
else: | |||||
self.best_score = score | |||||
if self.verbose: | |||||
self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') | |||||
# self.save_checkpoint(val_loss, model) | |||||
self.val_loss_min = val_loss | |||||
self.counter = 0 | |||||
# def save_checkpoint(self, val_loss, model): | |||||
# if self.verbose: | |||||
# self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...') | |||||
# torch.save(model.state_dict(), self.path) | |||||
# self.val_loss_min = val_loss |
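A minimal sketch of how the helpers above interact in a validation loop (toy numbers and a stand-in model; the real loop lives in `learner.supervised_train`):

```python
from torch import nn
from utils import AvgMeter, EarlyStopping, CheckpointSaving

meter = AvgMeter('validation')
meter.update(0.7, count=32)
meter.update(0.5, count=32)
print(meter)  # validation: 0.6000

model = nn.Linear(4, 2)  # stand-in for FakeNewsModel
early_stopping = EarlyStopping(patience=2, verbose=True, path='checkpoint.pt')
checkpoint_saving = CheckpointSaving(path='checkpoint.pt', verbose=True)

for val_loss, val_acc in [(0.9, 0.60), (0.8, 0.66), (0.85, 0.64), (0.86, 0.63)]:
    early_stopping(val_loss, model)     # tracks the validation-loss plateau
    checkpoint_saving(val_acc, model)   # saves on validation-accuracy improvement
    if early_stopping.early_stop:
        print('stopping early')
        break
```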