@@ -114,3 +114,6 @@ dmypy.json | |||
# Pyre type checker | |||
.pyre/ | |||
.idea/ | |||
env/ | |||
venv/ |
@@ -1,3 +1,48 @@ | |||
# FakeNewsRevealer | |||
Official implementation of the Fake News Revealer paper | |||
Official implementation of the "Fake News Revealer (FNR): A Similarity and Transformer-Based Approach to Detect | |||
Multi-Modal Fake News in Social Media" paper [(ArXiv Link)](https://arxiv.org/pdf/2112.01131.pdf). | |||
## Requirments | |||
FNR is built in Python 3.6 using PyTorch 1.8. Please use the following command to install the requirements: | |||
``` | |||
pip install -r requirements.txt | |||
``` | |||
## How to Run | |||
First, place the data address and configuration into the config file in the data directory, and then follow the train | |||
and test commands. | |||
### train | |||
To run with Optuna for parameter tuning use this command: | |||
``` | |||
python main --data "DATA NAME" --use_optuna "NUMBER OF OPTUNA TRIALS" --batch "BATCH SIZE" --epoch "EPOCHS NUMBER" | |||
``` | |||
To run without parameter tuning, adjust your parameters in the config file and then use the below command: | |||
``` | |||
python main --data "DATA NAME" --batch "BATCH SIZE" --epoch "EPOCHS NUMBER" | |||
``` | |||
### test | |||
In the test step, at first, make sure to have the requested 'checkpoint' file then run the following line: | |||
``` | |||
python main --data "DATA NAME" --just_test "REQESTED TRIAL NUMBER" | |||
``` | |||
## Bibtex | |||
Cite our paper using the following bibtex item: | |||
``` | |||
@misc{ghorbanpour2021fnr, | |||
title={FNR: A Similarity and Transformer-Based Approach to Detect Multi-Modal Fake News in Social Media}, | |||
author={Faeze Ghorbanpour and Maryam Ramezani and Mohammad A. Fazli and Hamid R. Rabiee}, | |||
year={2021}, | |||
eprint={2112.01131}, | |||
archivePrefix={arXiv}, | |||
} | |||
``` | |||
@@ -0,0 +1,117 @@ | |||
import torch | |||
import ast | |||
class Config: | |||
name = '' | |||
DatasetLoader = None | |||
debug = False | |||
data_path = '' | |||
output_path = '' | |||
train_image_path = '' | |||
validation_image_path = '' | |||
test_image_path = '' | |||
train_text_path = '' | |||
validation_text_path = '' | |||
test_text_path = '' | |||
train_text_embedding_path = '' | |||
validation_text_embedding_path = '' | |||
train_image_embedding_path = '' | |||
validation_image_embedding_path = '' | |||
batch_size = 256 | |||
epochs = 50 | |||
num_workers = 2 | |||
head_lr = 1e-03 | |||
image_encoder_lr = 1e-04 | |||
text_encoder_lr = 1e-04 | |||
attention_lr = 1e-3 | |||
classification_lr = 1e-03 | |||
max_grad_norm = 5.0 | |||
head_weight_decay = 0.001 | |||
attention_weight_decay = 0.001 | |||
classification_weight_decay = 0.001 | |||
patience = 30 | |||
delta = 0.0000001 | |||
factor = 0.8 | |||
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||
image_embedding = 768 | |||
num_img_region = 64 # 16 #TODO | |||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
text_embedding = 768 | |||
max_length = 32 | |||
pretrained = True # for both image encoder and text encoder | |||
trainable = False # for both image encoder and text encoder | |||
temperature = 1.0 | |||
# image size | |||
size = 224 | |||
num_projection_layers = 1 | |||
projection_size = 256 | |||
dropout = 0.3 | |||
hidden_size = 256 | |||
num_region = 64 # 16 #TODO | |||
region_size = projection_size // num_region | |||
classes = ['real', 'fake'] | |||
class_num = 2 | |||
loss_weight = 1 | |||
class_weights = [1, 1] | |||
writer = None | |||
has_unlabeled_data = False | |||
step = 0 | |||
T1 = 10 | |||
T2 = 150 | |||
af = 3 | |||
wanted_accuracy = 0.76 | |||
def optuna(self, trial): | |||
self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||
self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||
self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||
self.attention_lr = trial.suggest_loguniform('attention_lr', 1e-5, 1e-1) | |||
self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||
self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||
self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
def __str__(self): | |||
s = '' | |||
members = [attr for attr in dir(self) if not callable(getattr(self, attr)) and not attr.startswith("__")] | |||
for member in members: | |||
s += member + '\t' + str(getattr(self, member)) + '\n' | |||
return s | |||
def assign_hyperparameters(self, s): | |||
for line in s.split('\n'): | |||
s = line.split('\t') | |||
try: | |||
attr = getattr(self, s[0]) | |||
if type(attr) not in [list, set, dict]: | |||
setattr(self, s[0], type(attr)(s[1])) | |||
else: | |||
setattr(self, s[0], ast.literal_eval(s[1])) | |||
except: | |||
print(s[0], 'doesnot exist') |
@@ -0,0 +1,33 @@ | |||
import torch | |||
from torch.utils.data import Dataset | |||
class DatasetLoader(Dataset): | |||
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
self.config = config | |||
self.image_filenames = image_filenames | |||
self.text = list(texts) | |||
self.encoded_text = tokenizer( | |||
list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt' | |||
) | |||
self.labels = labels | |||
self.transforms = transforms | |||
self.mode = mode | |||
def set_text(self, idx): | |||
item = { | |||
key: values[idx].clone().detach() | |||
for key, values in self.encoded_text.items() | |||
} | |||
item['text'] = self.text[idx] | |||
item['label'] = self.labels[idx] | |||
item['id'] = idx | |||
return item | |||
def set_image(self, image, item): | |||
if 'resnet' in self.config.image_model_name: | |||
image = self.transforms(image=image)['image'] | |||
item['image'] = torch.as_tensor(image).reshape((3, 224, 224)) | |||
else: | |||
image = self.transforms(images=image, return_tensors='pt') | |||
image = image.convert_to_tensors(tensor_type='pt')['pixel_values'] | |||
item['image'] = image.reshape((3, 224, 224)) |
@@ -0,0 +1,65 @@ | |||
import torch | |||
from data.config import Config | |||
from data.twitter.data_loader import TwitterDatasetLoader | |||
class TwitterConfig(Config): | |||
name = 'twitter' | |||
DatasetLoader = TwitterDatasetLoader | |||
data_path = '/twitter/' | |||
output_path = '' | |||
train_image_path = data_path + 'images_train/' | |||
validation_image_path = data_path + 'images_validation/' | |||
test_image_path = data_path + 'images_test/' | |||
train_text_path = data_path + 'twitter_train_translated.csv' | |||
validation_text_path = data_path + 'twitter_validation_translated.csv' | |||
test_text_path = data_path + 'twitter_test_translated.csv' | |||
batch_size = 128 | |||
epochs = 100 | |||
num_workers = 2 | |||
head_lr = 1e-03 | |||
image_encoder_lr = 1e-04 | |||
text_encoder_lr = 1e-04 | |||
attention_lr = 1e-3 | |||
classification_lr = 1e-03 | |||
head_weight_decay = 0.001 | |||
attention_weight_decay = 0.001 | |||
classification_weight_decay = 0.001 | |||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
image_model_name = 'vit-base-patch16-224' | |||
image_embedding = 768 | |||
text_encoder_model = "bert-base-uncased" | |||
text_tokenizer = "bert-base-uncased" | |||
text_embedding = 768 | |||
max_length = 32 | |||
pretrained = True | |||
trainable = False | |||
temperature = 1.0 | |||
classes = ['real', 'fake'] | |||
class_weights = [1, 1] | |||
wanted_accuracy = 0.76 | |||
def optuna(self, trial): | |||
self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||
self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||
self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||
# self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||
# self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||
# self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
# self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) |
@@ -0,0 +1,65 @@ | |||
import cv2 | |||
import pickle | |||
from PIL import Image | |||
from numpy import asarray | |||
from data.data_loader import DatasetLoader | |||
class TwitterDatasetLoader(DatasetLoader): | |||
def __getitem__(self, idx): | |||
item = self.set_text(idx) | |||
try: | |||
if self.mode == 'train': | |||
image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}") | |||
elif self.mode == 'validation': | |||
image = cv2.imread(f"{self.config.validation_image_path}/{self.image_filenames[idx]}") | |||
else: | |||
image = cv2.imread(f"{self.config.test_image_path}/{self.image_filenames[idx]}") | |||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |||
except: | |||
image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
image = asarray(image) | |||
self.set_image(image, item) | |||
return item | |||
def __len__(self): | |||
return len(self.text) | |||
class TwitterEmbeddingDatasetLoader(DatasetLoader): | |||
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||
self.config = config | |||
self.image_filenames = image_filenames | |||
self.text = list(texts) | |||
self.encoded_text = tokenizer | |||
self.labels = labels | |||
self.transforms = transforms | |||
self.mode = mode | |||
if mode == 'train': | |||
with open(config.train_image_embedding_path, 'rb') as handle: | |||
self.image_embedding = pickle.load(handle) | |||
with open(config.train_text_embedding_path, 'rb') as handle: | |||
self.text_embedding = pickle.load(handle) | |||
else: | |||
with open(config.validation_image_embedding_path, 'rb') as handle: | |||
self.image_embedding = pickle.load(handle) | |||
with open(config.validation_text_embedding_path, 'rb') as handle: | |||
self.text_embedding = pickle.load(handle) | |||
def __getitem__(self, idx): | |||
item = dict() | |||
item['id'] = idx | |||
item['image_embedding'] = self.image_embedding[idx] | |||
item['text_embedding'] = self.text_embedding[idx] | |||
item['label'] = self.labels[idx] | |||
return item | |||
def __len__(self): | |||
return len(self.text) |
@@ -0,0 +1,62 @@ | |||
import torch | |||
from transformers import BertTokenizer, BertModel, BertConfig | |||
from data.config import Config | |||
from data.weibo.data_loader import WeiboDatasetLoader | |||
class WeiboConfig(Config): | |||
name = 'weibo' | |||
DatasetLoader = WeiboDatasetLoader | |||
data_path = 'weibo/' | |||
output_path = '' | |||
rumor_image_path = data_path + 'rumor_images/' | |||
nonrumor_image_path = data_path + 'nonrumor_images/' | |||
train_text_path = data_path + 'weibo_train.csv' | |||
validation_text_path = data_path + 'weibo_validation.csv' | |||
test_text_path = data_path + 'weibo_test.csv' | |||
batch_size = 128 | |||
epochs = 100 | |||
num_workers = 2 | |||
head_lr = 1e-03 | |||
image_encoder_lr = 1e-02 | |||
text_encoder_lr = 1e-05 | |||
weight_decay = 0.001 | |||
classification_lr = 1e-02 | |||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
image_model_name = 'vit-base-patch16-224' # 'resnet101' | |||
image_embedding = 768 # 2048 | |||
num_img_region = 64 # TODO | |||
text_encoder_model = "bert-base-chinese" | |||
text_tokenizer = "bert-base-chinese" | |||
text_embedding = 768 | |||
max_length = 200 | |||
pretrained = True | |||
trainable = False | |||
temperature = 1.0 | |||
labels = ['real', 'fake'] | |||
wanted_accuracy = 0.80 | |||
def optuna(self, trial): | |||
self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||
self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||
self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||
# self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||
self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||
# self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
# self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
@@ -0,0 +1,66 @@ | |||
import pickle | |||
import cv2 | |||
from PIL import Image | |||
from numpy import asarray | |||
from data.data_loader import DatasetLoader | |||
class WeiboDatasetLoader(DatasetLoader): | |||
def __getitem__(self, idx): | |||
item = self.set_text(idx) | |||
if self.labels[idx] == 1: | |||
try: | |||
image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}") | |||
except: | |||
image = Image.open(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
image = asarray(image) | |||
else: | |||
try: | |||
image = cv2.imread(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}") | |||
except: | |||
image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
image = asarray(image) | |||
self.set_image(image, item) | |||
return item | |||
def __len__(self): | |||
return len(self.text) | |||
class WeiboEmbeddingDatasetLoader(DatasetLoader): | |||
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||
self.config = config | |||
self.image_filenames = image_filenames | |||
self.text = list(texts) | |||
self.encoded_text = tokenizer | |||
self.labels = labels | |||
self.transforms = transforms | |||
self.mode = mode | |||
if mode == 'train': | |||
with open(config.train_image_embedding_path, 'rb') as handle: | |||
self.image_embedding = pickle.load(handle) | |||
with open(config.train_text_embedding_path, 'rb') as handle: | |||
self.text_embedding = pickle.load(handle) | |||
else: | |||
with open(config.validation_image_embedding_path, 'rb') as handle: | |||
self.image_embedding = pickle.load(handle) | |||
with open(config.validation_text_embedding_path, 'rb') as handle: | |||
self.text_embedding = pickle.load(handle) | |||
def __getitem__(self, idx): | |||
item = dict() | |||
item['id'] = idx | |||
item['image_embedding'] = self.image_embedding[idx] | |||
item['text_embedding'] = self.text_embedding[idx] | |||
item['label'] = self.labels[idx] | |||
return item | |||
def __len__(self): | |||
return len(self.text) |
@@ -0,0 +1,108 @@ | |||
import albumentations as A | |||
import numpy as np | |||
import pandas as pd | |||
from sklearn.utils import compute_class_weight | |||
from torch.utils.data import DataLoader | |||
from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||
def get_transforms(config, mode="train"): | |||
if mode == "train": | |||
return A.Compose( | |||
[ | |||
A.Resize(config.size, config.size, always_apply=True), | |||
A.Normalize(max_pixel_value=255.0, always_apply=True), | |||
] | |||
) | |||
else: | |||
return A.Compose( | |||
[ | |||
A.Resize(config.size, config.size, always_apply=True), | |||
A.Normalize(max_pixel_value=255.0, always_apply=True), | |||
] | |||
) | |||
def make_dfs(config): | |||
train_dataframe = pd.read_csv(config.train_text_path) | |||
train_dataframe.dropna(subset=['text'], inplace=True) | |||
train_dataframe = train_dataframe.sample(frac=1).reset_index(drop=True) | |||
train_dataframe.label = train_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
config.class_weights = get_class_weights(train_dataframe.label.values) | |||
if config.test_text_path is None: | |||
offset = int(train_dataframe.shape[0] * 0.80) | |||
test_dataframe = train_dataframe[offset:] | |||
train_dataframe = train_dataframe[:offset] | |||
else: | |||
test_dataframe = pd.read_csv(config.test_text_path) | |||
test_dataframe.dropna(subset=['text'], inplace=True) | |||
test_dataframe = test_dataframe.sample(frac=1).reset_index(drop=True) | |||
test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
if config.validation_text_path is None: | |||
offset = int(train_dataframe.shape[0] * 0.90) | |||
validation_dataframe = train_dataframe[offset:] | |||
train_dataframe = train_dataframe[:offset] | |||
else: | |||
validation_dataframe = pd.read_csv(config.validation_text_path) | |||
validation_dataframe.dropna(subset=['text'], inplace=True) | |||
validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True) | |||
validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
unlabeled_dataframe = None | |||
if config.has_unlabeled_data: | |||
unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path) | |||
unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True) | |||
return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe | |||
def get_tokenizer(config): | |||
if 'bigbird' in config.text_encoder_model: | |||
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||
elif 'xlnet' in config.text_encoder_model: | |||
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||
else: | |||
tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer) | |||
return tokenizer | |||
def build_loaders(config, dataframe, mode): | |||
if 'resnet' in config.image_model_name: | |||
transforms = get_transforms(config, mode=mode) | |||
else: | |||
transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name) | |||
dataset = config.DatasetLoader( | |||
config, | |||
dataframe["image"].values, | |||
dataframe["text"].values, | |||
dataframe["label"].values, | |||
tokenizer=get_tokenizer(config), | |||
transforms=transforms, | |||
mode=mode, | |||
) | |||
if mode != 'train': | |||
dataloader = DataLoader( | |||
dataset, | |||
batch_size=config.batch_size, | |||
num_workers=config.num_workers, | |||
pin_memory=False, | |||
shuffle=False, | |||
) | |||
else: | |||
dataloader = DataLoader( | |||
dataset, | |||
batch_size=config.batch_size, | |||
num_workers=config.num_workers, | |||
pin_memory=True, | |||
shuffle=True, | |||
) | |||
return dataloader | |||
def get_class_weights(y): | |||
class_weights = compute_class_weight('balanced', np.unique(y), y) | |||
return class_weights |
@@ -0,0 +1,257 @@ | |||
from itertools import cycle | |||
import matplotlib.pyplot as plt | |||
import numpy as np | |||
import pandas as pd | |||
import seaborn as sns | |||
from numpy import interp | |||
from sklearn.decomposition import PCA | |||
from sklearn.manifold import TSNE | |||
from sklearn.metrics import accuracy_score, f1_score, precision_recall_curve, average_precision_score | |||
from sklearn.metrics import classification_report, roc_curve, auc | |||
from sklearn.preprocessing import OneHotEncoder | |||
def metrics(truth, pred, prob, file_path): | |||
truth = [i.cpu().numpy() for i in truth] | |||
pred = [i.cpu().numpy() for i in pred] | |||
prob = [i.cpu().numpy() for i in prob] | |||
pred = np.concatenate(pred, axis=0) | |||
truth = np.concatenate(truth, axis=0) | |||
prob = np.concatenate(prob, axis=0) | |||
prob = prob[:, 1] | |||
f_score_micro = f1_score(truth, pred, average='micro', zero_division=0) | |||
f_score_macro = f1_score(truth, pred, average='macro', zero_division=0) | |||
f_score_weighted = f1_score(truth, pred, average='weighted', zero_division=0) | |||
accuarcy = accuracy_score(truth, pred) | |||
s = '' | |||
print('accuracy', accuarcy) | |||
s += 'accuracy' + str(accuarcy) + '\n' | |||
print('f_score_micro', f_score_micro) | |||
s += 'f_score_micro' + str(f_score_micro) + '\n' | |||
print('f_score_macro', f_score_macro) | |||
s += 'f_score_macro' + str(f_score_macro) + '\n' | |||
print('f_score_weighted', f_score_weighted) | |||
s += 'f_score_weighted' + str(f_score_weighted) + '\n' | |||
fpr, tpr, thresholds = roc_curve(truth, prob) | |||
AUC = auc(fpr, tpr) | |||
print('AUC', AUC) | |||
s += 'AUC' + str(AUC) + '\n' | |||
df = pd.DataFrame(dict(fpr=fpr, tpr=tpr)) | |||
df.to_csv(file_path) | |||
return s | |||
def report_per_class(truth, pred): | |||
truth = [i.cpu().numpy() for i in truth] | |||
pred = [i.cpu().numpy() for i in pred] | |||
pred = np.concatenate(pred, axis=0) | |||
truth = np.concatenate(truth, axis=0) | |||
report = classification_report(truth, pred, zero_division=0, output_dict=True) | |||
s = '' | |||
class_labels = [k for k in report.keys() if k not in ['micro avg', 'macro avg', 'weighted avg', 'samples avg']] | |||
for class_label in class_labels: | |||
print('class_label', class_label) | |||
s += 'class_label' + str(class_label) + '\n' | |||
s += str(report[class_label]) | |||
print(report[class_label]) | |||
return s | |||
def multiclass_acc(truth, pred): | |||
truth = [i.cpu().numpy() for i in truth] | |||
pred = [i.cpu().numpy() for i in pred] | |||
pred = np.concatenate(pred, axis=0) | |||
truth = np.concatenate(truth, axis=0) | |||
return accuracy_score(truth, pred) | |||
def roc_auc_plot(truth, score, num_class=2, fname='roc.png'): | |||
truth = [i.cpu().numpy() for i in truth] | |||
score = [i.cpu().numpy() for i in score] | |||
truth = np.concatenate(truth, axis=0) | |||
score = np.concatenate(score, axis=0) | |||
enc = OneHotEncoder(handle_unknown='ignore') | |||
enc.fit(truth.reshape(-1, 1)) | |||
label_onehot = enc.transform(truth.reshape(-1, 1)).toarray() | |||
fpr_dict = dict() | |||
tpr_dict = dict() | |||
roc_auc_dict = dict() | |||
for i in range(num_class): | |||
fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score[:, i]) | |||
roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i]) | |||
# micro | |||
fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score.ravel()) | |||
roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"]) | |||
# macro | |||
all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)])) | |||
mean_tpr = np.zeros_like(all_fpr) | |||
for i in range(num_class): | |||
mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i]) | |||
mean_tpr /= num_class | |||
fpr_dict["macro"] = all_fpr | |||
tpr_dict["macro"] = mean_tpr | |||
roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"]) | |||
plt.figure() | |||
lw = 2 | |||
plt.plot(fpr_dict["micro"], tpr_dict["micro"], | |||
label='micro-average ROC curve (area = {0:0.2f})' | |||
''.format(roc_auc_dict["micro"]), | |||
color='deeppink', linestyle=':', linewidth=4) | |||
plt.plot(fpr_dict["macro"], tpr_dict["macro"], | |||
label='macro-average ROC curve (area = {0:0.2f})' | |||
''.format(roc_auc_dict["macro"]), | |||
color='navy', linestyle=':', linewidth=4) | |||
colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) | |||
for i, color in zip(range(num_class), colors): | |||
plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw, | |||
label='ROC curve of class {0} (area = {1:0.2f})' | |||
''.format(i, roc_auc_dict[i])) | |||
plt.plot([0, 1], [0, 1], 'k--', lw=lw) | |||
plt.xlim([0.0, 1.0]) | |||
plt.ylim([0.0, 1.05]) | |||
plt.xlabel('False Positive Rate') | |||
plt.ylabel('True Positive Rate') | |||
plt.legend(loc="lower right") | |||
plt.savefig(fname) | |||
# plt.show() | |||
def precision_recall_plot(truth, score, num_class=2, fname='pr.png'): | |||
truth = [i.cpu().numpy() for i in truth] | |||
score = [i.cpu().numpy() for i in score] | |||
truth = np.concatenate(truth, axis=0) | |||
score = np.concatenate(score, axis=0) | |||
enc = OneHotEncoder(handle_unknown='ignore') | |||
enc.fit(truth.reshape(-1, 1)) | |||
label_onehot = enc.transform(truth.reshape(-1, 1)).toarray() | |||
# Call the Sklearn library, calculate the precision and recall corresponding to each category | |||
precision_dict = dict() | |||
recall_dict = dict() | |||
average_precision_dict = dict() | |||
for i in range(num_class): | |||
precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score[:, i]) | |||
average_precision_dict[i] = average_precision_score(label_onehot[:, i], score[:, i]) | |||
print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i]) | |||
# micro | |||
precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(), | |||
score.ravel()) | |||
average_precision_dict["micro"] = average_precision_score(label_onehot, score, average="micro") | |||
# macro | |||
all_fpr = np.unique(np.concatenate([precision_dict[i] for i in range(num_class)])) | |||
mean_tpr = np.zeros_like(all_fpr) | |||
for i in range(num_class): | |||
mean_tpr += interp(all_fpr, precision_dict[i], recall_dict[i]) | |||
mean_tpr /= num_class | |||
precision_dict["macro"] = all_fpr | |||
recall_dict["macro"] = mean_tpr | |||
average_precision_dict["macro"] = auc(precision_dict["macro"], recall_dict["macro"]) | |||
plt.figure() | |||
plt.subplots(figsize=(16, 10)) | |||
lw = 2 | |||
plt.plot(precision_dict["micro"], recall_dict["micro"], | |||
label='micro-average Precision-Recall curve (area = {0:0.2f})' | |||
''.format(average_precision_dict["micro"]), | |||
color='deeppink', linestyle=':', linewidth=4) | |||
plt.plot(precision_dict["macro"], recall_dict["macro"], | |||
label='macro-average Precision-Recall curve (area = {0:0.2f})' | |||
''.format(average_precision_dict["macro"]), | |||
color='navy', linestyle=':', linewidth=4) | |||
colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) | |||
for i, color in zip(range(num_class), colors): | |||
plt.plot(precision_dict[i], recall_dict[i], color=color, lw=lw, | |||
label='Precision-Recall curve of class {0} (area = {1:0.2f})' | |||
''.format(i, average_precision_dict[i])) | |||
plt.plot([0, 1], [0, 1], 'k--', lw=lw) | |||
plt.xlabel('Recall') | |||
plt.ylabel('Precision') | |||
plt.ylim([0.0, 1.05]) | |||
plt.xlim([0.0, 1.0]) | |||
plt.legend(loc="lower left") | |||
plt.savefig(fname=fname) | |||
# plt.show() | |||
def saving_in_tensorboard(config, x, y, fname='embedding'): | |||
x = [i.cpu().numpy() for i in x] | |||
y = [i.cpu().numpy() for i in y] | |||
x = np.concatenate(x, axis=0) | |||
y = np.concatenate(y, axis=0) | |||
z = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||
# config.writer.add_embedding(mat=x, label_img=y, metadata=z, tag=fname) | |||
def plot_tsne(config, x, y, fname='tsne.png'): | |||
x = [i.cpu().numpy() for i in x] | |||
y = [i.cpu().numpy() for i in y] | |||
x = np.concatenate(x, axis=0) | |||
y = np.concatenate(y, axis=0) | |||
y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||
tsne = TSNE(n_components=2, verbose=1, init="pca", perplexity=10, learning_rate=1000) | |||
tsne_proj = tsne.fit_transform(x) | |||
fig, ax = plt.subplots(figsize=(16, 10)) | |||
palette = sns.color_palette("bright", 2) | |||
sns.scatterplot(tsne_proj[:, 0], tsne_proj[:, 1], hue=y, legend='full', palette=palette) | |||
ax.legend(fontsize='large', markerscale=2) | |||
plt.title('tsne of ' + str(fname.split('/')[-1].split('.')[0])) | |||
plt.savefig(fname=fname) | |||
# plt.show() | |||
def plot_pca(config, x, y, fname='pca.png'): | |||
x = [i.cpu().numpy() for i in x] | |||
y = [i.cpu().numpy() for i in y] | |||
x = np.concatenate(x, axis=0) | |||
y = np.concatenate(y, axis=0) | |||
y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||
pca = PCA(n_components=2) | |||
pca_proj = pca.fit_transform(x) | |||
fig, ax = plt.subplots(figsize=(16, 10)) | |||
palette = sns.color_palette("bright", 2) | |||
sns.scatterplot(pca_proj[:, 0], pca_proj[:, 1], hue=y, legend='full', palette=palette) | |||
ax.legend(fontsize='large', markerscale=2) | |||
plt.title('pca of ' + str(fname.split('/')[-1].split('.')[0])) | |||
plt.savefig(fname=fname) | |||
# plt.show() |
@@ -0,0 +1,45 @@ | |||
import timm | |||
from torch import nn | |||
from transformers import ViTModel, ViTConfig | |||
class ImageTransformerEncoder(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
if config.pretrained: | |||
self.model = ViTModel.from_pretrained(config.image_model_name, output_attentions=False, | |||
output_hidden_states=True, return_dict=True) | |||
else: | |||
self.model = ViTModel(config=ViTConfig()) | |||
for p in self.model.parameters(): | |||
p.requires_grad = config.trainable | |||
self.target_token_idx = 0 | |||
self.image_encoder_embedding = dict() | |||
def forward(self, ids, image): | |||
output = self.model(image) | |||
last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :] | |||
# for i, id in enumerate(ids): | |||
# id = int(id.detach().cpu().numpy()) | |||
# self.image_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy() | |||
return last_hidden_state | |||
class ImageResnetEncoder(nn.Module): | |||
def __init__( | |||
self, config | |||
): | |||
super().__init__() | |||
self.model = timm.create_model( | |||
config.image_model_name, config.pretrained, num_classes=0, global_pool="avg" | |||
) | |||
for p in self.model.parameters(): | |||
p.requires_grad = config.trainable | |||
def forward(self, ids, image): | |||
return self.model(image) |
@@ -0,0 +1,206 @@ | |||
import gc | |||
import itertools | |||
import random | |||
import numpy as np | |||
import optuna | |||
import pandas as pd | |||
import torch | |||
from evaluation import multiclass_acc | |||
from model import FakeNewsModel, calculate_loss | |||
from utils import AvgMeter, print_lr, EarlyStopping, CheckpointSaving | |||
def batch_constructor(config, batch): | |||
b = {} | |||
for k, v in batch.items(): | |||
if k != 'text': | |||
b[k] = v.to(config.device) | |||
return b | |||
def train_epoch(config, model, train_loader, optimizer, scalar): | |||
loss_meter = AvgMeter('train') | |||
# tqdm_object = tqdm(train_loader, total=len(train_loader)) | |||
targets = [] | |||
predictions = [] | |||
for index, batch in enumerate(train_loader): | |||
batch = batch_constructor(config, batch) | |||
optimizer.zero_grad(set_to_none=True) | |||
with torch.cuda.amp.autocast(): | |||
output, score = model(batch) | |||
loss = calculate_loss(model, score, batch['label']) | |||
scalar.scale(loss).backward() | |||
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) | |||
if (index + 1) % 2: | |||
scalar.step(optimizer) | |||
# loss.backward() | |||
# optimizer.step() | |||
scalar.update() | |||
count = batch["id"].size(0) | |||
loss_meter.update(loss.detach(), count) | |||
# tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer)) | |||
prediction = output.detach() | |||
predictions.append(prediction) | |||
target = batch['label'].detach() | |||
targets.append(target) | |||
return loss_meter, targets, predictions | |||
def validation_epoch(config, model, validation_loader): | |||
loss_meter = AvgMeter('validation') | |||
targets = [] | |||
predictions = [] | |||
# tqdm_object = tqdm(validation_loader, total=len(validation_loader)) | |||
for batch in validation_loader: | |||
batch = batch_constructor(config, batch) | |||
with torch.no_grad(): | |||
output, score = model(batch) | |||
loss = calculate_loss(model, score, batch['label']) | |||
count = batch["id"].size(0) | |||
loss_meter.update(loss.detach(), count) | |||
# tqdm_object.set_postfix(validation_loss=loss_meter.avg) | |||
prediction = output.detach() | |||
predictions.append(prediction) | |||
target = batch['label'].detach() | |||
targets.append(target) | |||
return loss_meter, targets, predictions | |||
def supervised_train(config, train_loader, validation_loader, trial=None): | |||
torch.cuda.empty_cache() | |||
checkpoint_path2 = checkpoint_path = str(config.output_path) + '/checkpoint.pt' | |||
if trial: | |||
checkpoint_path2 = str(config.output_path) + '/checkpoint_' + str(trial.number) + '.pt' | |||
torch.manual_seed(27) | |||
random.seed(27) | |||
np.random.seed(27) | |||
torch.autograd.set_detect_anomaly(False) | |||
torch.autograd.profiler.profile(False) | |||
torch.autograd.profiler.emit_nvtx(False) | |||
scalar = torch.cuda.amp.GradScaler() | |||
model = FakeNewsModel(config).to(config.device) | |||
params = [ | |||
{"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr, "name": 'image_encoder'}, | |||
{"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr, "name": 'text_encoder'}, | |||
{"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()), | |||
"lr": config.head_lr, "weight_decay": config.head_weight_decay, 'name': 'projection'}, | |||
{"params": model.classifier.parameters(), "lr": config.classification_lr, | |||
"weight_decay": config.classification_weight_decay, | |||
'name': 'classifier'} | |||
] | |||
optimizer = torch.optim.AdamW(params, amsgrad=True) | |||
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.factor, | |||
patience=config.patience // 5, verbose=True) | |||
early_stopping = EarlyStopping(patience=config.patience, delta=config.delta, path=checkpoint_path, verbose=True) | |||
checkpoint_saving = CheckpointSaving(path=checkpoint_path, verbose=True) | |||
train_losses, train_accuracies = [], [] | |||
validation_losses, validation_accuracies = [], [] | |||
validation_accuracy, validation_loss = 0, 1 | |||
for epoch in range(config.epochs): | |||
print(f"Epoch: {epoch + 1}") | |||
gc.collect() | |||
model.train() | |||
train_loss, train_truth, train_pred = train_epoch(config, model, train_loader, optimizer, scalar) | |||
model.eval() | |||
with torch.no_grad(): | |||
validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader) | |||
train_accuracy = multiclass_acc(train_truth, train_pred) | |||
validation_accuracy = multiclass_acc(validation_truth, validation_pred) | |||
print_lr(optimizer) | |||
print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy) | |||
print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy) | |||
if lr_scheduler: | |||
lr_scheduler.step(validation_loss.avg) | |||
if early_stopping: | |||
early_stopping(validation_loss.avg, model) | |||
if early_stopping.early_stop: | |||
print("Early stopping") | |||
break | |||
if checkpoint_saving: | |||
checkpoint_saving(validation_accuracy, model) | |||
train_accuracies.append(train_accuracy) | |||
train_losses.append(train_loss) | |||
validation_accuracies.append(validation_accuracy) | |||
validation_losses.append(validation_loss) | |||
if trial: | |||
trial.report(validation_accuracy, epoch) | |||
if trial.should_prune(): | |||
print('trial pruned') | |||
raise optuna.exceptions.TrialPruned() | |||
print() | |||
if checkpoint_saving: | |||
model = FakeNewsModel(config).to(config.device) | |||
model.load_state_dict(torch.load(checkpoint_path)) | |||
model.eval() | |||
with torch.no_grad(): | |||
validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader) | |||
validation_accuracy = multiclass_acc(validation_pred, validation_truth) | |||
if trial and validation_accuracy >= config.wanted_accuracy: | |||
loss_accuracy = pd.DataFrame( | |||
{'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses, | |||
'validation_accuracy': validation_accuracies}) | |||
torch.save({'model_state_dict': model.state_dict(), | |||
'parameters': str(config), | |||
'optimizer_state_dict': optimizer.state_dict(), | |||
'loss_accuracy': loss_accuracy}, checkpoint_path2) | |||
if not checkpoint_saving: | |||
loss_accuracy = pd.DataFrame( | |||
{'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses, | |||
'validation_accuracy': validation_accuracies}) | |||
torch.save(model.state_dict(), checkpoint_path) | |||
if trial and validation_accuracy >= config.wanted_accuracy: | |||
torch.save({'model_state_dict': model.state_dict(), | |||
'parameters': str(config), | |||
'optimizer_state_dict': optimizer.state_dict(), | |||
'loss_accuracy': loss_accuracy}, checkpoint_path2) | |||
try: | |||
del train_loss | |||
del train_truth | |||
del train_pred | |||
del validation_loss | |||
del validation_truth | |||
del validation_pred | |||
del train_losses | |||
del train_accuracies | |||
del validation_losses | |||
del validation_accuracies | |||
del loss_accuracy | |||
del scalar | |||
del model | |||
del params | |||
except: | |||
print('Error in deleting caches') | |||
pass | |||
return validation_accuracy | |||
@@ -0,0 +1,63 @@ | |||
import argparse | |||
from optuna_main import optuna_main | |||
from test_main import test_main | |||
from torch_main import torch_main | |||
from data.twitter.config import TwitterConfig | |||
from data.weibo.config import WeiboConfig | |||
from torch.utils.tensorboard import SummaryWriter | |||
if __name__ == '__main__': | |||
import os | |||
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--data', type=str, required=True) | |||
parser.add_argument('--use_optuna', type=int, required=False) | |||
parser.add_argument('--just_test', type=int, required=False) | |||
parser.add_argument('--batch', type=int, required=False) | |||
parser.add_argument('--epoch', type=int, required=False) | |||
parser.add_argument('--extra', type=str, required=False) | |||
args = parser.parse_args() | |||
if args.data == 'twitter': | |||
config = TwitterConfig() | |||
elif args.data == 'weibo': | |||
config = WeiboConfig() | |||
else: | |||
raise Exception('Enter a valid dataset name', args.data) | |||
if args.batch: | |||
config.batch_size = args.batch | |||
if args.epoch: | |||
config.epochs = args.epoch | |||
if args.use_optuna: | |||
config.output_path += 'logs/' + args.data + '_optuna' + '_' + str(args.extra) | |||
else: | |||
config.output_path += 'logs/' + args.data + '_' + str(args.extra) | |||
use_optuna = True | |||
if not args.extra or 'temp' in args.extra: | |||
config.output_path = str(args.extra) | |||
use_optuna = False | |||
try: | |||
os.mkdir(config.output_path) | |||
except OSError: | |||
print("Creation of the directory failed") | |||
else: | |||
print("Successfully created the directory") | |||
config.writer = SummaryWriter(config.output_path) | |||
if args.use_optuna and use_optuna: | |||
optuna_main(config, args.use_optuna) | |||
elif args.just_test: | |||
test_main(config, args.just_test) | |||
else: | |||
torch_main(config) | |||
config.writer.close() |
@@ -0,0 +1,143 @@ | |||
import torch | |||
import torch.nn.functional as F | |||
from torch import nn | |||
from image import ImageTransformerEncoder, ImageResnetEncoder | |||
from text import TextEncoder | |||
class ProjectionHead(nn.Module): | |||
def __init__( | |||
self, | |||
config, | |||
embedding_dim, | |||
): | |||
super().__init__() | |||
self.config = config | |||
self.projection = nn.Linear(embedding_dim, config.projection_size) | |||
self.gelu = nn.GELU() | |||
self.fc = nn.Linear(config.projection_size, config.projection_size) | |||
self.dropout = nn.Dropout(config.dropout) | |||
self.layer_norm = nn.LayerNorm(config.projection_size) | |||
def forward(self, x): | |||
projected = self.projection(x) | |||
x = self.gelu(projected) | |||
x = self.fc(x) | |||
x = self.dropout(x) | |||
x = x + projected | |||
x = self.layer_norm(x) | |||
return x | |||
class Classifier(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.layer_norm_1 = nn.LayerNorm(2 * config.projection_size) | |||
self.linear_layer = nn.Linear(2 * config.projection_size, config.hidden_size) | |||
self.gelu = nn.GELU() | |||
self.drop_out = nn.Dropout(config.dropout) | |||
self.layer_norm_2 = nn.LayerNorm(config.hidden_size) | |||
self.classifier_layer = nn.Linear(config.hidden_size, config.class_num) | |||
self.softmax = nn.Softmax(dim=1) | |||
def forward(self, x): | |||
x = self.layer_norm_1(x) | |||
x = self.linear_layer(x) | |||
x = self.gelu(x) | |||
x = self.drop_out(x) | |||
self.embeddings = x = self.layer_norm_2(x) | |||
x = self.classifier_layer(x) | |||
x = self.softmax(x) | |||
return x | |||
class FakeNewsModel(nn.Module): | |||
def __init__( | |||
self, config | |||
): | |||
super().__init__() | |||
self.config = config | |||
if 'resnet' in self.config.image_model_name: | |||
self.image_encoder = ImageResnetEncoder(config) | |||
else: | |||
self.image_encoder = ImageTransformerEncoder(config) | |||
self.text_encoder = TextEncoder(config) | |||
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||
self.classifier = Classifier(config) | |||
self.temperature = config.temperature | |||
class_weights = torch.FloatTensor(config.class_weights) | |||
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||
self.text_embeddings = None | |||
self.image_embeddings = None | |||
self.multimodal_embeddings = None | |||
def forward(self, batch): | |||
image_features = self.image_encoder(ids=batch['id'], image=batch["image"]) | |||
text_features = self.text_encoder(ids=batch['id'], | |||
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) | |||
self.image_embeddings = self.image_projection(image_features) | |||
self.text_embeddings = self.text_projection(text_features) | |||
self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||
score = self.classifier(self.multimodal_embeddings) | |||
probs, output = torch.max(score.data, dim=1) | |||
return output, score | |||
class FakeNewsEmbeddingModel(nn.Module): | |||
def __init__( | |||
self, config | |||
): | |||
super().__init__() | |||
self.config = config | |||
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||
self.classifier = Classifier(config) | |||
self.temperature = config.temperature | |||
class_weights = torch.FloatTensor(config.class_weights) | |||
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||
def forward(self, batch): | |||
self.image_embeddings = self.image_projection(batch['image_embedding']) | |||
self.text_embeddings = self.text_projection(batch['text_embedding']) | |||
self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||
score = self.classifier(self.multimodal_embeddings) | |||
probs, output = torch.max(score.data, dim=1) | |||
return output, score | |||
def calculate_loss(model, score, label): | |||
similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings) | |||
# fake = (label == 1).nonzero() | |||
# real = (label == 0).nonzero() | |||
# s_loss = 0 * similarity[fake].mean() + similarity[real].mean() | |||
c_loss = model.classifier_loss_function(score, label) | |||
loss = model.config.loss_weight * c_loss + similarity.mean() | |||
return loss | |||
def calculate_similarity_loss(config, image_embeddings, text_embeddings): | |||
# Calculating the Loss | |||
logits = (text_embeddings @ image_embeddings.T) / config.temperature | |||
images_similarity = image_embeddings @ image_embeddings.T | |||
texts_similarity = text_embeddings @ text_embeddings.T | |||
targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1) | |||
texts_loss = cross_entropy(logits, targets, reduction='none') | |||
images_loss = cross_entropy(logits.T, targets.T, reduction='none') | |||
loss = (images_loss + texts_loss) / 2.0 | |||
return loss | |||
def cross_entropy(preds, targets, reduction='none'): | |||
log_softmax = nn.LogSoftmax(dim=-1) | |||
loss = (-targets * log_softmax(preds)).sum(1) | |||
if reduction == "none": | |||
return loss | |||
elif reduction == "mean": | |||
return loss.mean() |
@@ -0,0 +1,62 @@ | |||
import joblib | |||
import optuna | |||
from optuna.pruners import MedianPruner | |||
from optuna.trial import TrialState | |||
from data_loaders import make_dfs, build_loaders | |||
from learner import supervised_train | |||
from test_main import test | |||
def objective(trial, config, train_loader, validation_loader): | |||
config.optuna(trial=trial) | |||
print('Trial', trial.number, 'parameters', trial.params) | |||
accuracy = supervised_train(config, train_loader, validation_loader, trial=trial) | |||
return accuracy | |||
def optuna_main(config, n_trials=100): | |||
train_df, test_df, validation_df, _ = make_dfs(config) | |||
train_loader = build_loaders(config, train_df, mode="train") | |||
validation_loader = build_loaders(config, validation_df, mode="validation") | |||
test_loader = build_loaders(config, test_df, mode="test") | |||
study = optuna.create_study(study_name=config.output_path.split('/')[-1], | |||
sampler=optuna.samplers.TPESampler(), | |||
storage=f'sqlite:///{config.output_path + "/optuna.db"}', | |||
load_if_exists=True, | |||
direction="maximize", | |||
pruner=MedianPruner(n_startup_trials=10, n_warmup_steps=10) | |||
) | |||
study.optimize(lambda trial: objective(trial, config, train_loader, validation_loader), n_trials=n_trials) | |||
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED]) | |||
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE]) | |||
joblib.dump(study, str(config.output_path) + '/study_optuna_model' + '.pkl') | |||
print("Study statistics: ") | |||
print(" Number of finished trials: ", len(study.trials)) | |||
print(" Number of pruned trials: ", len(pruned_trials)) | |||
print(" Number of complete trials: ", len(complete_trials)) | |||
s = '' | |||
print("Best trial:") | |||
trial = study.best_trial | |||
print(' Number: ', trial.number) | |||
print(" Value: ", trial.value) | |||
s += 'value: ' + str(trial.value) + '\n' | |||
print(" Params: ") | |||
s += 'params: \n' | |||
for key, value in trial.params.items(): | |||
print(" {}: {}".format(key, value)) | |||
s += " {}: {}\n".format(key, value) | |||
with open(config.output_path+'/optuna_results.txt', 'w') as f: | |||
f.write(s) | |||
test(config, test_loader, trial_number=trial.number) | |||
@@ -0,0 +1,17 @@ | |||
opencv-python==4.5.1.48 | |||
optuna==2.8.0 | |||
pandas==1.1.5 | |||
Pillow==8.2.0 | |||
pytorch-lightning==1.2.10 | |||
pytorch-model-summary==0.1.2 | |||
pytorchtools==0.0.2 | |||
scikit-image==0.17.2 | |||
scikit-learn==0.24.2 | |||
scipy==1.5.4 | |||
seaborn==0.11.1 | |||
torch==1.8.1 | |||
torchmetrics==0.2.0 | |||
torchsummary==1.5.1 | |||
torchvision==0.9.1 | |||
tqdm==4.60.0 | |||
transformers==4.5.1 |
@@ -0,0 +1,111 @@ | |||
import random | |||
import numpy as np | |||
import torch | |||
from tqdm import tqdm | |||
from data_loaders import make_dfs, build_loaders | |||
from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca | |||
from learner import batch_constructor | |||
from model import FakeNewsModel | |||
def test(config, test_loader, trial_number=None): | |||
if trial_number: | |||
try: | |||
checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt') | |||
except: | |||
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt') | |||
else: | |||
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device)) | |||
try: | |||
parameters = checkpoint['parameters'] | |||
config.assign_hyperparameters(parameters) | |||
except: | |||
pass | |||
model = FakeNewsModel(config).to(config.device) | |||
try: | |||
model.load_state_dict(checkpoint['model_state_dict']) | |||
except: | |||
model.load_state_dict(checkpoint) | |||
model.eval() | |||
torch.manual_seed(27) | |||
random.seed(27) | |||
np.random.seed(27) | |||
image_features = [] | |||
text_features = [] | |||
multimodal_features = [] | |||
concat_features = [] | |||
targets = [] | |||
predictions = [] | |||
scores = [] | |||
tqdm_object = tqdm(test_loader, total=len(test_loader)) | |||
for i, batch in enumerate(tqdm_object): | |||
batch = batch_constructor(config, batch) | |||
with torch.no_grad(): | |||
output, score = model(batch) | |||
prediction = output.detach() | |||
predictions.append(prediction) | |||
score = score.detach() | |||
scores.append(score) | |||
target = batch['label'].detach() | |||
targets.append(target) | |||
image_feature = model.image_embeddings.detach() | |||
image_features.append(image_feature) | |||
text_feature = model.text_embeddings.detach() | |||
text_features.append(text_feature) | |||
multimodal_feature = model.multimodal_embeddings.detach() | |||
multimodal_features.append(multimodal_feature) | |||
concat_feature = model.classifier.embeddings.detach() | |||
concat_features.append(concat_feature) | |||
# config.writer.add_graph(model, input_to_model=batch, verbose=True) | |||
s = '' | |||
s += report_per_class(targets, predictions) + '\n' | |||
s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n' | |||
with open(config.output_path + '/results.txt', 'w') as f: | |||
f.write(s) | |||
roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png") | |||
precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png") | |||
# saving_in_tensorboard(config, image_features, targets, 'image_features') | |||
plot_tsne(config, image_features, targets, fname=str(config.output_path) + '/image_features_tsne.png') | |||
plot_pca(config, image_features, targets, fname=str(config.output_path) + '/image_features_pca.png') | |||
# saving_in_tensorboard(config, text_features, targets, 'text_features') | |||
plot_tsne(config, text_features, targets, fname=str(config.output_path) + '/text_features_tsne.png') | |||
plot_pca(config, text_features, targets, fname=str(config.output_path) + '/text_features_pca.png') | |||
# | |||
# saving_in_tensorboard(config, multimodal_features, targets, 'multimodal_features') | |||
plot_tsne(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_tsne.png') | |||
plot_pca(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_pca.png') | |||
# saving_in_tensorboard(config, concat_features, targets, 'concat_features') | |||
plot_tsne(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_tsne.png') | |||
plot_pca(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_pca.png') | |||
config_parameters = str(config) | |||
with open(config.output_path + '/parameters.txt', 'w') as f: | |||
f.write(config_parameters) | |||
print(config) | |||
def test_main(config, trial_number=None): | |||
train_df, test_df, validation_df, _ = make_dfs(config, ) | |||
test_loader = build_loaders(config, test_df, mode="test") | |||
test(config, test_loader, trial_number) |
@@ -0,0 +1,47 @@ | |||
from torch import nn | |||
from transformers import BertModel, BertConfig, BertTokenizer, \ | |||
BigBirdModel, BigBirdConfig, BigBirdTokenizer, \ | |||
XLNetModel, XLNetConfig, XLNetTokenizer | |||
class TextEncoder(nn.Module): | |||
def __init__(self, config): | |||
super().__init__() | |||
self.config = config | |||
self.model = get_text_model(config) | |||
for p in self.model.parameters(): | |||
p.requires_grad = self.config.trainable | |||
self.target_token_idx = 0 | |||
self.text_encoder_embedding = dict() | |||
def forward(self, ids, input_ids, attention_mask): | |||
output = self.model(input_ids=input_ids, attention_mask=attention_mask) | |||
last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :] | |||
# for i, id in enumerate(ids): | |||
# id = int(id.detach().cpu().numpy()) | |||
# self.text_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy() | |||
return last_hidden_state | |||
def get_text_model(config): | |||
if 'bigbird' in config.text_encoder_model: | |||
if config.pretrained: | |||
model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2) | |||
else: | |||
model = BigBirdModel(config=BigBirdConfig()) | |||
elif 'xlnet' in config.text_encoder_model: | |||
if config.pretrained: | |||
model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||
output_hidden_states=True, return_dict=True) | |||
else: | |||
model = XLNetModel(config=XLNetConfig()) | |||
else: | |||
if config.pretrained: | |||
model = BertModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||
output_hidden_states=True, return_dict=True) | |||
else: | |||
model = BertModel(config=BertConfig()) | |||
return model | |||
@@ -0,0 +1,19 @@ | |||
from data_loaders import build_loaders, make_dfs | |||
from learner import supervised_train | |||
from test_main import test | |||
def torch_main(config): | |||
train_df, test_df, validation_df, _ = make_dfs(config, ) | |||
train_loader = build_loaders(config, train_df, mode="train") | |||
validation_loader = build_loaders(config, validation_df, mode="validation") | |||
test_loader = build_loaders(config, test_df, mode="test") | |||
supervised_train(config, train_loader, validation_loader) | |||
test(config, test_loader) | |||
@@ -0,0 +1,93 @@ | |||
import numpy as np | |||
import torch | |||
class AvgMeter: | |||
def __init__(self, name="Metric"): | |||
self.name = name | |||
self.reset() | |||
def reset(self): | |||
self.avg, self.sum, self.count = [0] * 3 | |||
def update(self, val, count=1): | |||
self.count += count | |||
self.sum += val * count | |||
self.avg = self.sum / self.count | |||
def __repr__(self): | |||
text = f"{self.name}: {self.avg:.4f}" | |||
return text | |||
def print_lr(optimizer): | |||
for param_group in optimizer.param_groups: | |||
print(param_group['name'], param_group['lr']) | |||
class CheckpointSaving: | |||
def __init__(self, path='checkpoint.pt', verbose=True, trace_func=print): | |||
self.best_score = None | |||
self.val_acc_max = 0 | |||
self.path = path | |||
self.verbose = verbose | |||
self.trace_func = trace_func | |||
def __call__(self, val_acc, model): | |||
if self.best_score is None: | |||
self.best_score = val_acc | |||
self.save_checkpoint(val_acc, model) | |||
elif val_acc > self.best_score: | |||
self.best_score = val_acc | |||
self.save_checkpoint(val_acc, model) | |||
def save_checkpoint(self, val_acc, model): | |||
if self.verbose: | |||
self.trace_func( | |||
f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}). Model saved ...') | |||
torch.save(model.state_dict(), self.path) | |||
self.val_acc_max = val_acc | |||
class EarlyStopping: | |||
def __init__(self, patience=10, verbose=False, delta=0.000001, path='checkpoint.pt', trace_func=print): | |||
self.patience = patience | |||
self.verbose = verbose | |||
self.counter = 0 | |||
self.best_score = None | |||
self.early_stop = False | |||
self.val_loss_min = np.Inf | |||
self.delta = delta | |||
self.path = path | |||
self.trace_func = trace_func | |||
def __call__(self, val_loss, model): | |||
score = -val_loss | |||
if self.best_score is None: | |||
self.best_score = score | |||
if self.verbose: | |||
self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') | |||
self.val_loss_min = val_loss | |||
# self.save_checkpoint(val_loss, model) | |||
elif score < self.best_score + self.delta: | |||
self.counter += 1 | |||
self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') | |||
if self.counter >= self.patience: | |||
self.early_stop = True | |||
else: | |||
self.best_score = score | |||
if self.verbose: | |||
self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') | |||
# self.save_checkpoint(val_loss, model) | |||
self.val_loss_min = val_loss | |||
self.counter = 0 | |||
# def save_checkpoint(self, val_loss, model): | |||
# if self.verbose: | |||
# self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...') | |||
# torch.save(model.state_dict(), self.path) | |||
# self.val_loss_min = val_loss |