Browse Source

codes added

master
FaezeGhorbanpour 8 months ago
parent
commit
0719ba7a98
20 changed files with 1628 additions and 1 deletions
  1. 3
    0
      .gitignore
  2. 46
    1
      README.md
  3. 117
    0
      data/config.py
  4. 33
    0
      data/data_loader.py
  5. 65
    0
      data/twitter/config.py
  6. 65
    0
      data/twitter/data_loader.py
  7. 62
    0
      data/weibo/config.py
  8. 66
    0
      data/weibo/data_loader.py
  9. 108
    0
      data_loaders.py
  10. 257
    0
      evaluation.py
  11. 45
    0
      image.py
  12. 206
    0
      learner.py
  13. 63
    0
      main.py
  14. 143
    0
      model.py
  15. 62
    0
      optuna_main.py
  16. 17
    0
      requirements.txt
  17. 111
    0
      test_main.py
  18. 47
    0
      text.py
  19. 19
    0
      torch_main.py
  20. 93
    0
      utils.py

+ 3
- 0
.gitignore View File

@@ -114,3 +114,6 @@ dmypy.json
# Pyre type checker
.pyre/

.idea/
env/
venv/

+ 46
- 1
README.md View File

@@ -1,3 +1,48 @@
# FakeNewsRevealer

Official implementation of the Fake News Revealer paper
Official implementation of the "Fake News Revealer (FNR): A Similarity and Transformer-Based Approach to Detect
Multi-Modal Fake News in Social Media" paper [(ArXiv Link)](https://arxiv.org/pdf/2112.01131.pdf).


## Requirements
FNR is built in Python 3.6 using PyTorch 1.8. Please use the following command to install the requirements:

```
pip install -r requirements.txt
```

## How to Run
First, place the data address and configuration into the config file in the data directory, and then follow the train
and test commands.

### train
To run with Optuna for parameter tuning use this command:

```
python main.py --data "DATA NAME" --use_optuna "NUMBER OF OPTUNA TRIALS" --batch "BATCH SIZE" --epoch "EPOCHS NUMBER"
```

To run without parameter tuning, adjust your parameters in the config file and then use the below command:
```
python main.py --data "DATA NAME" --batch "BATCH SIZE" --epoch "EPOCHS NUMBER"
```

### test
For the test step, first make sure the required 'checkpoint' file is available, then run the following command:
```
python main.py --data "DATA NAME" --just_test "REQUESTED TRIAL NUMBER"
```

## Bibtex
Cite our paper using the following bibtex item:

```
@misc{ghorbanpour2021fnr,
title={FNR: A Similarity and Transformer-Based Approach to Detect Multi-Modal Fake News in Social Media},
author={Faeze Ghorbanpour and Maryam Ramezani and Mohammad A. Fazli and Hamid R. Rabiee},
year={2021},
eprint={2112.01131},
archivePrefix={arXiv},
}
```


+ 117
- 0
data/config.py View File

@@ -0,0 +1,117 @@
import torch
import ast


class Config:
    """Default paths and hyper-parameters shared by every dataset configuration.

    All settings are class attributes so that dataset-specific subclasses can
    override them declaratively. ``optuna`` and ``assign_hyperparameters``
    overwrite individual settings on the instance.
    """

    name = ''
    DatasetLoader = None
    debug = False
    data_path = ''
    output_path = ''

    train_image_path = ''
    validation_image_path = ''
    test_image_path = ''

    train_text_path = ''
    validation_text_path = ''
    test_text_path = ''

    train_text_embedding_path = ''
    validation_text_embedding_path = ''

    train_image_embedding_path = ''
    validation_image_embedding_path = ''

    batch_size = 256
    epochs = 50
    num_workers = 2
    head_lr = 1e-03
    image_encoder_lr = 1e-04
    text_encoder_lr = 1e-04
    attention_lr = 1e-3
    classification_lr = 1e-03

    max_grad_norm = 5.0

    head_weight_decay = 0.001
    attention_weight_decay = 0.001
    classification_weight_decay = 0.001

    patience = 30
    delta = 0.0000001
    factor = 0.8

    image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
    image_embedding = 768
    num_img_region = 64  # 16 #TODO
    text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
    text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"

    text_embedding = 768
    max_length = 32

    pretrained = True  # for both image encoder and text encoder
    trainable = False  # for both image encoder and text encoder
    temperature = 1.0

    # image size
    size = 224

    num_projection_layers = 1
    projection_size = 256
    dropout = 0.3
    hidden_size = 256
    num_region = 64  # 16 #TODO
    # Evaluated once, at class-definition time, from the defaults above.
    region_size = projection_size // num_region

    classes = ['real', 'fake']
    class_num = 2

    loss_weight = 1
    class_weights = [1, 1]

    writer = None

    has_unlabeled_data = False
    step = 0
    T1 = 10
    T2 = 150
    af = 3

    wanted_accuracy = 0.76

    def optuna(self, trial):
        """Overwrite the tunable hyper-parameters with values suggested by ``trial``.

        Args:
            trial: Object exposing ``suggest_loguniform`` and
                ``suggest_categorical`` (an Optuna trial).
        """
        self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1)
        self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3)
        self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3)
        self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)
        self.attention_lr = trial.suggest_loguniform('attention_lr', 1e-5, 1e-1)

        self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
        self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
        self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)

        self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
        self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
        self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])

    def __str__(self):
        """Return one ``name<TAB>value`` line per non-callable, non-dunder attribute."""
        members = [
            attr for attr in dir(self)
            if not attr.startswith("__") and not callable(getattr(self, attr))
        ]
        return ''.join(f'{member}\t{getattr(self, member)}\n' for member in members)

    @staticmethod
    def _parse_value(current, text):
        """Convert ``text`` to the type of the attribute's ``current`` value."""
        if isinstance(current, bool):
            # bool('False') is True, so booleans must be parsed by name.
            lowered = text.strip().lower()
            if lowered in ('true', '1'):
                return True
            if lowered in ('false', '0'):
                return False
            raise ValueError(f'not a boolean: {text!r}')
        if current is None or isinstance(current, (list, tuple, set, dict)):
            return ast.literal_eval(text)
        return type(current)(text)

    def assign_hyperparameters(self, s):
        """Load settings from text in the format produced by ``__str__``.

        Each non-blank line is ``name<TAB>value``. The value is converted to
        the type of the attribute's current value; containers and ``None`` are
        parsed with ``ast.literal_eval``. Unknown names and unparsable values
        are reported on stdout and skipped.

        Args:
            s: Multi-line string of tab-separated name/value pairs.
        """
        for line in s.split('\n'):
            if not line.strip():
                continue
            fields = line.split('\t')
            key = fields[0]
            try:
                current = getattr(self, key)
            except AttributeError:
                print(key, 'doesnot exist')
                continue
            try:
                setattr(self, key, self._parse_value(current, fields[1]))
            except (IndexError, ValueError, TypeError, SyntaxError):
                print(key, 'could not be parsed')

+ 33
- 0
data/data_loader.py View File

@@ -0,0 +1,33 @@
import torch
from torch.utils.data import Dataset

class DatasetLoader(Dataset):
    """Base dataset that tokenizes all texts up front and prepares image tensors.

    Subclasses supply ``__getitem__``/``__len__`` and call ``set_text`` and
    ``set_image`` to assemble a sample dictionary.
    """

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        """Store the raw inputs and tokenize every text once.

        Args:
            config: Configuration object; ``max_length`` and
                ``image_model_name`` are read from it.
            image_filenames: Indexable collection of image file names.
            texts: Iterable of raw text strings.
            labels: Indexable collection of labels.
            tokenizer: Callable returning a mapping of tensors when invoked
                with ``return_tensors='pt'``.
            transforms: Image pre-processing callable (see ``set_image``).
            mode: Split identifier stored on the instance as ``self.mode``.
        """
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        self.encoded_text = tokenizer(
            list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt'
        )
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

    def set_text(self, idx):
        """Build the text part of sample ``idx``.

        Returns:
            dict with a detached copy of every tokenizer output row plus the
            keys ``'text'``, ``'label'`` and ``'id'``.
        """
        item = {
            key: values[idx].clone().detach()
            for key, values in self.encoded_text.items()
        }
        item['text'] = self.text[idx]
        item['label'] = self.labels[idx]
        item['id'] = idx
        return item

    def set_image(self, image, item):
        """Pre-process ``image`` and store it in ``item['image']`` as (3, 224, 224).

        Args:
            image: Decoded image array.
            item: Sample dictionary, modified in place.
        """
        if 'resnet' in self.config.image_model_name:
            image = self.transforms(image=image)['image']
            tensor = torch.as_tensor(image)
            # A channels-last (H, W, 3) array must be transposed, not reshaped:
            # a bare reshape would reinterpret memory and scramble the pixels.
            if tensor.ndim == 3 and tensor.shape[-1] == 3:
                tensor = tensor.permute(2, 0, 1)
            item['image'] = tensor.reshape((3, 224, 224))
        else:
            image = self.transforms(images=image, return_tensors='pt')
            image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
            # Drops the leading batch dimension of a single-image result.
            item['image'] = image.reshape((3, 224, 224))

+ 65
- 0
data/twitter/config.py View File

@@ -0,0 +1,65 @@
import torch

from data.config import Config
from data.twitter.data_loader import TwitterDatasetLoader


class TwitterConfig(Config):
    """Configuration overrides for the Twitter dataset."""

    name = 'twitter'
    DatasetLoader = TwitterDatasetLoader

    data_path = '/twitter/'
    output_path = ''

    # Image folders, one per split.
    train_image_path = data_path + 'images_train/'
    validation_image_path = data_path + 'images_validation/'
    test_image_path = data_path + 'images_test/'

    # CSV files holding the text of each split.
    train_text_path = data_path + 'twitter_train_translated.csv'
    validation_text_path = data_path + 'twitter_validation_translated.csv'
    test_text_path = data_path + 'twitter_test_translated.csv'

    batch_size = 128
    epochs = 100
    num_workers = 2
    head_lr = 1e-03
    image_encoder_lr = 1e-04
    text_encoder_lr = 1e-04
    attention_lr = 1e-3
    classification_lr = 1e-03

    head_weight_decay = 0.001
    attention_weight_decay = 0.001
    classification_weight_decay = 0.001

    # First CUDA device when available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    image_model_name = 'vit-base-patch16-224'
    image_embedding = 768
    text_encoder_model = "bert-base-uncased"
    text_tokenizer = "bert-base-uncased"
    text_embedding = 768
    max_length = 32

    pretrained = True
    trainable = False
    temperature = 1.0

    classes = ['real', 'fake']
    class_weights = [1, 1]

    wanted_accuracy = 0.76

    def optuna(self, trial):
        """Sample learning rates and weight decays from ``trial``.

        Only the log-uniform parameters listed below are tuned here;
        attention weight decay, projection size, hidden size and dropout are
        left at their configured values.
        """
        log_uniform_ranges = (
            ('head_lr', 1e-5, 1e-1),
            ('image_encoder_lr', 1e-6, 1e-3),
            ('text_encoder_lr', 1e-6, 1e-3),
            ('classification_lr', 1e-5, 1e-1),
            ('head_weight_decay', 1e-5, 1e-1),
            ('classification_weight_decay', 1e-5, 1e-1),
        )
        for attribute, low, high in log_uniform_ranges:
            setattr(self, attribute, trial.suggest_loguniform(attribute, low, high))

+ 65
- 0
data/twitter/data_loader.py View File

@@ -0,0 +1,65 @@

import cv2
import pickle
from PIL import Image
from numpy import asarray

from data.data_loader import DatasetLoader


class TwitterDatasetLoader(DatasetLoader):
    """Dataset yielding tokenized text plus the decoded image of a tweet."""

    def _image_dir(self):
        """Return the image directory configured for the current split."""
        if self.mode == 'train':
            return self.config.train_image_path
        if self.mode == 'validation':
            return self.config.validation_image_path
        return self.config.test_image_path

    def __getitem__(self, idx):
        """Return the sample dictionary for index ``idx``.

        The image is decoded with OpenCV and converted to RGB. When OpenCV
        cannot decode it, the same file is opened with PIL instead.
        """
        item = self.set_text(idx)
        image_path = f"{self._image_dir()}/{self.image_filenames[idx]}"

        image = None
        try:
            decoded = cv2.imread(image_path)
            # imread signals failure by returning None, not by raising.
            if decoded is not None:
                image = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
        except cv2.error:
            image = None
        if image is None:
            # Fall back to PIL on the file of the *current* split.
            image = asarray(Image.open(image_path).convert('RGB'))

        self.set_image(image, item)
        return item

    def __len__(self):
        """Return the number of samples."""
        return len(self.text)


class TwitterEmbeddingDatasetLoader(DatasetLoader):
    """Dataset serving pre-computed image/text embeddings loaded from pickle files."""

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        """Initialise the base dataset, then load the embeddings for ``mode``.

        ``mode == 'train'`` selects the training embedding files; every other
        mode selects the validation files.
        """
        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
        # The base class already stored config, filenames, labels, transforms
        # and mode; only these two attributes are assigned again here.
        self.text = list(texts)
        self.encoded_text = tokenizer

        split = 'train' if mode == 'train' else 'validation'
        # NOTE: pickle must only be used on trusted, locally produced files.
        with open(getattr(config, f'{split}_image_embedding_path'), 'rb') as handle:
            self.image_embedding = pickle.load(handle)
        with open(getattr(config, f'{split}_text_embedding_path'), 'rb') as handle:
            self.text_embedding = pickle.load(handle)

    def __getitem__(self, idx):
        """Return id, both embeddings and the label of sample ``idx``."""
        return {
            'id': idx,
            'image_embedding': self.image_embedding[idx],
            'text_embedding': self.text_embedding[idx],
            'label': self.labels[idx],
        }

    def __len__(self):
        """Return the number of samples."""
        return len(self.text)

+ 62
- 0
data/weibo/config.py View File

@@ -0,0 +1,62 @@
import torch
from transformers import BertTokenizer, BertModel, BertConfig

from data.config import Config
from data.weibo.data_loader import WeiboDatasetLoader


class WeiboConfig(Config):
    """Configuration overrides for the Weibo dataset."""

    name = 'weibo'
    DatasetLoader = WeiboDatasetLoader

    data_path = 'weibo/'
    output_path = ''

    # Images are stored per class rather than per split.
    rumor_image_path = data_path + 'rumor_images/'
    nonrumor_image_path = data_path + 'nonrumor_images/'

    # CSV files holding the text of each split.
    train_text_path = data_path + 'weibo_train.csv'
    validation_text_path = data_path + 'weibo_validation.csv'
    test_text_path = data_path + 'weibo_test.csv'

    batch_size = 128
    epochs = 100
    num_workers = 2
    head_lr = 1e-03
    image_encoder_lr = 1e-02
    text_encoder_lr = 1e-05
    weight_decay = 0.001
    classification_lr = 1e-02

    # First CUDA device when available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    image_model_name = 'vit-base-patch16-224'  # 'resnet101'
    image_embedding = 768  # 2048
    num_img_region = 64  # TODO
    text_encoder_model = "bert-base-chinese"
    text_tokenizer = "bert-base-chinese"
    text_embedding = 768
    max_length = 200

    pretrained = True
    trainable = False
    temperature = 1.0

    labels = ['real', 'fake']

    wanted_accuracy = 0.80

    def optuna(self, trial):
        """Sample learning rates, weight decays and projection size from ``trial``.

        Attention weight decay, hidden size and dropout are not tuned here and
        keep their configured values.
        """
        log_uniform_ranges = (
            ('head_lr', 1e-5, 1e-1),
            ('image_encoder_lr', 1e-6, 1e-3),
            ('text_encoder_lr', 1e-6, 1e-3),
            ('classification_lr', 1e-5, 1e-1),
            ('head_weight_decay', 1e-5, 1e-1),
            ('classification_weight_decay', 1e-5, 1e-1),
        )
        for attribute, low, high in log_uniform_ranges:
            setattr(self, attribute, trial.suggest_loguniform(attribute, low, high))

        self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])


+ 66
- 0
data/weibo/data_loader.py View File

@@ -0,0 +1,66 @@

import pickle
import cv2
from PIL import Image
from numpy import asarray

from data.data_loader import DatasetLoader



class WeiboDatasetLoader(DatasetLoader):
    """Dataset yielding tokenized text plus the decoded image of a Weibo post."""

    def __getitem__(self, idx):
        """Return the sample dictionary for index ``idx``.

        Samples labelled ``1`` are read from ``config.rumor_image_path``, all
        others from ``config.nonrumor_image_path``. The image is decoded with
        OpenCV and converted to RGB; when OpenCV cannot decode it, the same
        file is opened with PIL instead.
        """
        item = self.set_text(idx)

        if self.labels[idx] == 1:
            image_dir = self.config.rumor_image_path
        else:
            image_dir = self.config.nonrumor_image_path
        image_path = f"{image_dir}/{self.image_filenames[idx]}"

        image = None
        try:
            decoded = cv2.imread(image_path)
            # imread signals failure by returning None, not by raising, so the
            # fallback has to be triggered explicitly.
            if decoded is not None:
                # OpenCV decodes to BGR; convert so both code paths yield RGB.
                image = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
        except cv2.error:
            image = None
        if image is None:
            image = asarray(Image.open(image_path).convert('RGB'))

        self.set_image(image, item)
        return item

    def __len__(self):
        """Return the number of samples."""
        return len(self.text)


class WeiboEmbeddingDatasetLoader(DatasetLoader):
    """Dataset serving pre-computed image/text embeddings loaded from pickle files."""

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        """Initialise the base dataset, then load the embeddings for ``mode``.

        ``mode == 'train'`` selects the training embedding files; every other
        mode selects the validation files.
        """
        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
        # The base class already stored config, filenames, labels, transforms
        # and mode; only these two attributes are assigned again here.
        self.text = list(texts)
        self.encoded_text = tokenizer

        split = 'train' if mode == 'train' else 'validation'
        # NOTE: pickle must only be used on trusted, locally produced files.
        with open(getattr(config, f'{split}_image_embedding_path'), 'rb') as handle:
            self.image_embedding = pickle.load(handle)
        with open(getattr(config, f'{split}_text_embedding_path'), 'rb') as handle:
            self.text_embedding = pickle.load(handle)

    def __getitem__(self, idx):
        """Return id, both embeddings and the label of sample ``idx``."""
        return {
            'id': idx,
            'image_embedding': self.image_embedding[idx],
            'text_embedding': self.text_embedding[idx],
            'label': self.labels[idx],
        }

    def __len__(self):
        """Return the number of samples."""
        return len(self.text)

+ 108
- 0
data_loaders.py View File

@@ -0,0 +1,108 @@
import albumentations as A
import numpy as np
import pandas as pd

from sklearn.utils import compute_class_weight
from torch.utils.data import DataLoader
from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer


def get_transforms(config, mode="train"):
    """Build the albumentations pipeline: resize to ``config.size`` and normalize.

    Args:
        config: Configuration object providing ``size``.
        mode: Split name. The same pipeline is returned for every mode.

    Returns:
        An ``A.Compose`` applying ``Resize`` followed by ``Normalize``.
    """
    steps = [
        A.Resize(config.size, config.size, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)


def make_dfs(config):
    """Load, shuffle and label-encode the train/test/validation dataframes.

    Side effect: ``config.class_weights`` is replaced by the weights computed
    from the training labels.

    When ``config.test_text_path`` is ``None`` the last 20% of the training
    frame becomes the test split. When ``config.validation_text_path`` is
    ``None`` the last 10% of the remaining training frame becomes the
    validation split.

    Returns:
        Tuple ``(train, test, validation, unlabeled)``; ``unlabeled`` is
        ``None`` unless ``config.has_unlabeled_data`` is truthy.
    """

    def _shuffle(frame):
        # Random permutation of the rows with a fresh 0..n-1 index.
        return frame.sample(frac=1).reset_index(drop=True)

    def _read_labeled(path):
        # Read a CSV, drop rows without text, shuffle, map labels to indices.
        frame = pd.read_csv(path)
        frame.dropna(subset=['text'], inplace=True)
        frame = _shuffle(frame)
        frame['label'] = frame['label'].apply(lambda value: config.classes.index(value))
        return frame

    train_dataframe = _read_labeled(config.train_text_path)
    config.class_weights = get_class_weights(train_dataframe.label.values)

    if config.test_text_path is not None:
        test_dataframe = _read_labeled(config.test_text_path)
    else:
        cut = int(train_dataframe.shape[0] * 0.80)
        test_dataframe = train_dataframe[cut:]
        train_dataframe = train_dataframe[:cut]

    if config.validation_text_path is not None:
        validation_dataframe = _read_labeled(config.validation_text_path)
    else:
        cut = int(train_dataframe.shape[0] * 0.90)
        validation_dataframe = train_dataframe[cut:]
        train_dataframe = train_dataframe[:cut]

    unlabeled_dataframe = None
    if config.has_unlabeled_data:
        unlabeled_dataframe = _shuffle(pd.read_csv(config.unlabeled_text_path))

    return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe


def get_tokenizer(config):
    """Load the tokenizer matching ``config.text_encoder_model``.

    The tokenizer class is selected by substring: ``'bigbird'`` selects
    ``BigBirdTokenizer``, ``'xlnet'`` selects ``XLNetTokenizer`` and anything
    else selects ``BertTokenizer``. The weights are loaded from
    ``config.text_tokenizer``.
    """
    encoder_name = config.text_encoder_model
    if 'bigbird' in encoder_name:
        tokenizer_class = BigBirdTokenizer
    elif 'xlnet' in encoder_name:
        tokenizer_class = XLNetTokenizer
    else:
        tokenizer_class = BertTokenizer
    return tokenizer_class.from_pretrained(config.text_tokenizer)


def build_loaders(config, dataframe, mode):
    """Create the ``DataLoader`` for one split.

    Args:
        config: Configuration providing ``DatasetLoader``, ``batch_size``,
            ``num_workers`` and ``image_model_name``.
        dataframe: Frame with ``image``, ``text`` and ``label`` columns.
        mode: Split name; only ``'train'`` enables shuffling and pinned memory.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``config.DatasetLoader``.
    """
    if 'resnet' not in config.image_model_name:
        transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)
    else:
        transforms = get_transforms(config, mode=mode)

    dataset = config.DatasetLoader(
        config,
        dataframe["image"].values,
        dataframe["text"].values,
        dataframe["label"].values,
        tokenizer=get_tokenizer(config),
        transforms=transforms,
        mode=mode,
    )

    is_train = mode == 'train'
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=is_train,
        shuffle=is_train,
    )


def get_class_weights(y):
    """Return 'balanced' class weights for the label array ``y``.

    ``classes`` and ``y`` are passed by keyword because recent scikit-learn
    releases reject them as positional arguments.

    Args:
        y: Array-like of class labels, one per sample.

    Returns:
        Array with one weight per distinct label, ordered like ``np.unique(y)``.
    """
    return compute_class_weight('balanced', classes=np.unique(y), y=y)

+ 257
- 0
evaluation.py View File

@@ -0,0 +1,257 @@
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from numpy import interp
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, f1_score, precision_recall_curve, average_precision_score
from sklearn.metrics import classification_report, roc_curve, auc
from sklearn.preprocessing import OneHotEncoder


def metrics(truth, pred, prob, file_path):
    """Print and collect accuracy, F1 variants and AUC; save ROC points as CSV.

    Args:
        truth: Iterable of label tensors, one per batch.
        pred: Iterable of predicted-label tensors, one per batch.
        prob: Iterable of 2-D score tensors; column 1 is used for the ROC curve.
        file_path: Destination of the CSV with ``fpr``/``tpr`` columns.

    Returns:
        str: One ``<name><value>`` line per metric (no separator between name
        and value).
    """
    truth, pred, prob = (
        [tensor.cpu().numpy() for tensor in batches] for batches in (truth, pred, prob)
    )

    pred = np.concatenate(pred, axis=0)
    truth = np.concatenate(truth, axis=0)
    positive_prob = np.concatenate(prob, axis=0)[:, 1]

    f_scores = [
        (f'f_score_{average}', f1_score(truth, pred, average=average, zero_division=0))
        for average in ('micro', 'macro', 'weighted')
    ]
    named_values = [('accuracy', accuracy_score(truth, pred))] + f_scores

    lines = []
    for label, value in named_values:
        print(label, value)
        lines.append(label + str(value) + '\n')

    fpr, tpr, thresholds = roc_curve(truth, positive_prob)
    area = auc(fpr, tpr)
    print('AUC', area)
    lines.append('AUC' + str(area) + '\n')

    pd.DataFrame(dict(fpr=fpr, tpr=tpr)).to_csv(file_path)

    return ''.join(lines)


def report_per_class(truth, pred):
    """Print and return the per-entry rows of scikit-learn's classification report.

    Every key of the report dictionary except the four ``* avg`` rows is
    emitted.

    Args:
        truth: Iterable of label tensors, one per batch.
        pred: Iterable of predicted-label tensors, one per batch.

    Returns:
        str: Concatenation of ``'class_label<key>\\n'`` and the stringified
        report entry for each emitted key.
    """
    truth_batches = [tensor.cpu().numpy() for tensor in truth]
    pred_batches = [tensor.cpu().numpy() for tensor in pred]

    predicted = np.concatenate(pred_batches, axis=0)
    expected = np.concatenate(truth_batches, axis=0)

    report = classification_report(expected, predicted, zero_division=0, output_dict=True)

    averaged_rows = ('micro avg', 'macro avg', 'weighted avg', 'samples avg')
    pieces = []
    for class_label, entry in report.items():
        if class_label in averaged_rows:
            continue
        print('class_label', class_label)
        pieces.append('class_label' + str(class_label) + '\n')
        pieces.append(str(entry))
        print(entry)

    return ''.join(pieces)


def multiclass_acc(truth, pred):
    """Return the accuracy of batched predictions against batched labels.

    Args:
        truth: Iterable of label tensors, one per batch.
        pred: Iterable of prediction tensors, one per batch.

    Returns:
        The value of ``accuracy_score`` over the concatenated batches.
    """
    truth_batches = [tensor.cpu().numpy() for tensor in truth]
    pred_batches = [tensor.cpu().numpy() for tensor in pred]

    predicted = np.concatenate(pred_batches, axis=0)
    expected = np.concatenate(truth_batches, axis=0)

    return accuracy_score(expected, predicted)


def roc_auc_plot(truth, score, num_class=2, fname='roc.png'):
    """Plot per-class, micro- and macro-averaged ROC curves and save the figure.

    Args:
        truth: Iterable of label tensors, one per batch.
        score: Iterable of 2-D score tensors with one column per class.
        num_class: Number of classes (score columns) to plot.
        fname: Output path handed to ``plt.savefig``.
    """
    truth = np.concatenate([tensor.cpu().numpy() for tensor in truth], axis=0)
    score = np.concatenate([tensor.cpu().numpy() for tensor in score], axis=0)

    # One-hot encode the labels so each class gets its own binary problem.
    truth_column = truth.reshape(-1, 1)
    encoder = OneHotEncoder(handle_unknown='ignore')
    encoder.fit(truth_column)
    label_onehot = encoder.transform(truth_column).toarray()

    fpr, tpr, area = {}, {}, {}
    class_ids = range(num_class)
    for cls in class_ids:
        fpr[cls], tpr[cls], _ = roc_curve(label_onehot[:, cls], score[:, cls])
        area[cls] = auc(fpr[cls], tpr[cls])

    # Micro average: pool every (label, score) pair across classes.
    fpr["micro"], tpr["micro"], _ = roc_curve(label_onehot.ravel(), score.ravel())
    area["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro average: mean of the per-class curves on a shared FPR grid.
    fpr_grid = np.unique(np.concatenate([fpr[cls] for cls in class_ids]))
    mean_tpr = np.zeros_like(fpr_grid)
    for cls in class_ids:
        mean_tpr += interp(fpr_grid, fpr[cls], tpr[cls])
    mean_tpr /= num_class
    fpr["macro"] = fpr_grid
    tpr["macro"] = mean_tpr
    area["macro"] = auc(fpr["macro"], tpr["macro"])

    plt.figure()

    lw = 2
    averaged_styles = (("micro", 'deeppink'), ("macro", 'navy'))
    for kind, colour in averaged_styles:
        template = kind + '-average ROC curve (area = {0:0.2f})'
        plt.plot(fpr[kind], tpr[kind], label=template.format(area[kind]),
                 color=colour, linestyle=':', linewidth=4)

    palette = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for cls, colour in zip(class_ids, palette):
        plt.plot(fpr[cls], tpr[cls], color=colour, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(cls, area[cls]))

    # Chance-level diagonal.
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc="lower right")
    plt.savefig(fname)
    # plt.show()


def precision_recall_plot(truth, score, num_class=2, fname='pr.png'):
    """Plot per-class, micro- and macro-averaged precision-recall curves.

    Recall is drawn on the x axis and precision on the y axis, matching the
    axis labels.

    Args:
        truth: Iterable of label tensors, one per batch.
        score: Iterable of 2-D score tensors with one column per class.
        num_class: Number of classes (score columns) to plot.
        fname: Output path handed to ``plt.savefig``.
    """
    truth = [i.cpu().numpy() for i in truth]
    score = [i.cpu().numpy() for i in score]

    truth = np.concatenate(truth, axis=0)
    score = np.concatenate(score, axis=0)

    enc = OneHotEncoder(handle_unknown='ignore')
    enc.fit(truth.reshape(-1, 1))
    label_onehot = enc.transform(truth.reshape(-1, 1)).toarray()

    # Precision, recall and average precision for each class.
    precision_dict = dict()
    recall_dict = dict()
    average_precision_dict = dict()
    for i in range(num_class):
        precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score[:, i])
        average_precision_dict[i] = average_precision_score(label_onehot[:, i], score[:, i])
        print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i])

    # micro
    precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(),
                                                                              score.ravel())
    average_precision_dict["micro"] = average_precision_score(label_onehot, score, average="micro")

    # macro
    # NOTE(review): np.interp expects increasing sample points, and the
    # per-class precision values are not guaranteed to be monotonic; the
    # original averaging scheme is kept unchanged here.
    all_precision = np.unique(np.concatenate([precision_dict[i] for i in range(num_class)]))
    mean_recall = np.zeros_like(all_precision)
    for i in range(num_class):
        mean_recall += interp(all_precision, precision_dict[i], recall_dict[i])
    mean_recall /= num_class
    precision_dict["macro"] = all_precision
    recall_dict["macro"] = mean_recall
    average_precision_dict["macro"] = auc(precision_dict["macro"], recall_dict["macro"])

    # A single figure; no separate plt.figure() call that would stay empty.
    plt.subplots(figsize=(16, 10))
    lw = 2
    plt.plot(recall_dict["micro"], precision_dict["micro"],
             label='micro-average Precision-Recall curve (area = {0:0.2f})'
                   ''.format(average_precision_dict["micro"]),
             color='deeppink', linestyle=':', linewidth=4)

    plt.plot(recall_dict["macro"], precision_dict["macro"],
             label='macro-average Precision-Recall curve (area = {0:0.2f})'
                   ''.format(average_precision_dict["macro"]),
             color='navy', linestyle=':', linewidth=4)

    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(num_class), colors):
        plt.plot(recall_dict[i], precision_dict[i], color=color, lw=lw,
                 label='Precision-Recall curve of class {0} (area = {1:0.2f})'
                       ''.format(i, average_precision_dict[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])

    plt.legend(loc="lower left")
    plt.savefig(fname=fname)
    # plt.show()


def saving_in_tensorboard(config, x, y, fname='embedding'):
    """Prepare embeddings and class-name metadata for TensorBoard.

    The arrays are built, but the ``add_embedding`` call is currently
    disabled, so nothing is written and ``None`` is returned.

    Args:
        config: Configuration providing ``classes``.
        x: Iterable of embedding tensors, one per batch.
        y: Iterable of integer label tensors, one per batch.
        fname: Tag intended for the embedding.
    """
    embedding_batches = [tensor.cpu().numpy() for tensor in x]
    label_batches = [tensor.cpu().numpy() for tensor in y]

    x = np.concatenate(embedding_batches, axis=0)
    y = np.concatenate(label_batches, axis=0)

    label_column = pd.DataFrame(y)[0]
    z = label_column.apply(lambda index: config.classes[index]).values

    # config.writer.add_embedding(mat=x, label_img=y, metadata=z, tag=fname)

def plot_tsne(config, x, y, fname='tsne.png'):
    """Project embeddings to 2-D with t-SNE and save a class-coloured scatter plot.

    Args:
        config: Configuration providing ``classes`` (index -> class name).
        x: Iterable of embedding tensors, one per batch.
        y: Iterable of integer label tensors, one per batch.
        fname: Output path; its base name without extension is used in the title.
    """
    x = [i.cpu().numpy() for i in x]
    y = [i.cpu().numpy() for i in y]

    x = np.concatenate(x, axis=0)
    y = np.concatenate(y, axis=0)

    # Replace label indices by class names for the legend.
    y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values

    tsne = TSNE(n_components=2, verbose=1, init="pca", perplexity=10, learning_rate=1000)
    tsne_proj = tsne.fit_transform(x)

    fig, ax = plt.subplots(figsize=(16, 10))

    palette = sns.color_palette("bright", 2)
    # x and y are passed by keyword: recent seaborn releases reject them as
    # positional arguments.
    sns.scatterplot(x=tsne_proj[:, 0], y=tsne_proj[:, 1], hue=y, legend='full', palette=palette, ax=ax)

    ax.legend(fontsize='large', markerscale=2)
    plt.title('tsne of ' + str(fname.split('/')[-1].split('.')[0]))
    plt.savefig(fname=fname)
    # plt.show()


def plot_pca(config, x, y, fname='pca.png'):
    """Project embeddings to 2-D with PCA and save a class-coloured scatter plot.

    Args:
        config: Configuration providing ``classes`` (index -> class name).
        x: Iterable of embedding tensors, one per batch.
        y: Iterable of integer label tensors, one per batch.
        fname: Output path; its base name without extension is used in the title.
    """
    x = [i.cpu().numpy() for i in x]
    y = [i.cpu().numpy() for i in y]

    x = np.concatenate(x, axis=0)
    y = np.concatenate(y, axis=0)

    # Replace label indices by class names for the legend.
    y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values

    pca = PCA(n_components=2)
    pca_proj = pca.fit_transform(x)

    fig, ax = plt.subplots(figsize=(16, 10))

    palette = sns.color_palette("bright", 2)
    # x and y are passed by keyword: recent seaborn releases reject them as
    # positional arguments.
    sns.scatterplot(x=pca_proj[:, 0], y=pca_proj[:, 1], hue=y, legend='full', palette=palette, ax=ax)

    ax.legend(fontsize='large', markerscale=2)
    plt.title('pca of ' + str(fname.split('/')[-1].split('.')[0]))
    plt.savefig(fname=fname)
    # plt.show()

+ 45
- 0
image.py View File

@@ -0,0 +1,45 @@
import timm
from torch import nn
from transformers import ViTModel, ViTConfig


class ImageTransformerEncoder(nn.Module):
    """ViT image encoder returning the hidden state of one token per image."""

    def __init__(self, config):
        """Build the ViT backbone.

        Args:
            config: Provides ``pretrained`` (load weights from
                ``image_model_name`` or start from a default ``ViTConfig``)
                and ``trainable`` (``requires_grad`` for every backbone
                parameter).
        """
        super().__init__()
        if not config.pretrained:
            self.model = ViTModel(config=ViTConfig())
        else:
            self.model = ViTModel.from_pretrained(
                config.image_model_name,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

        for parameter in self.model.parameters():
            parameter.requires_grad = config.trainable

        # Position of the token whose hidden state represents the image.
        self.target_token_idx = 0
        self.image_encoder_embedding = {}

    def forward(self, ids, image):
        """Encode ``image`` and return the hidden state at ``target_token_idx``.

        ``ids`` is accepted but not used; caching embeddings per id is
        currently disabled.
        """
        hidden_states = self.model(image).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]


class ImageResnetEncoder(nn.Module):
    """Image encoder backed by a ``timm`` model without a classification head."""

    def __init__(self, config):
        """Create the backbone named by ``config.image_model_name``.

        ``config.pretrained`` is forwarded to ``timm.create_model`` and
        ``config.trainable`` sets ``requires_grad`` on every parameter.
        """
        super().__init__()
        self.model = timm.create_model(
            config.image_model_name,
            config.pretrained,
            num_classes=0,
            global_pool="avg",
        )
        for parameter in self.model.parameters():
            parameter.requires_grad = config.trainable

    def forward(self, ids, image):
        """Return the backbone output for ``image``; ``ids`` is not used."""
        features = self.model(image)
        return features

+ 206
- 0
learner.py View File

@@ -0,0 +1,206 @@
import gc
import itertools
import random

import numpy as np
import optuna
import pandas as pd
import torch

from evaluation import multiclass_acc
from model import FakeNewsModel, calculate_loss
from utils import AvgMeter, print_lr, EarlyStopping, CheckpointSaving


def batch_constructor(config, batch):
    """Move every entry of ``batch`` except ``'text'`` to ``config.device``.

    Args:
        config: Configuration providing ``device``.
        batch: Mapping of names to tensors, plus the raw ``'text'`` entry.

    Returns:
        dict: New mapping without the ``'text'`` key, values on the device.
    """
    return {
        key: value.to(config.device)
        for key, value in batch.items()
        if key != 'text'
    }


def train_epoch(config, model, train_loader, optimizer, scalar):
# Iterate over train_loader, back-propagating the scaled loss of each batch.
#
# Args:
#   config: configuration; max_grad_norm is read here and the whole object is
#       handed to batch_constructor.
#   model: called as model(batch), expected to return (output, score).
#   train_loader: iterable of batch dictionaries containing 'label' and 'id'.
#   optimizer: optimizer stepped through the gradient scaler.
#   scalar: gradient scaler exposing scale/step/update
#       (supervised_train passes torch.cuda.amp.GradScaler()).
#
# Returns:
#   (loss_meter, targets, predictions): the AvgMeter updated with every batch
#   loss, and per-batch lists of detached labels and detached model outputs.
#
# NOTE(review): indentation is lost in this view, so the exact extent of the
# `with` and `if` bodies below cannot be confirmed from here.
loss_meter = AvgMeter('train')
# tqdm_object = tqdm(train_loader, total=len(train_loader))

targets = []
predictions = []
for index, batch in enumerate(train_loader):
# Drops the raw 'text' entry and moves the remaining values to config.device.
batch = batch_constructor(config, batch)
optimizer.zero_grad(set_to_none=True)

# Forward pass under mixed-precision autocast.
with torch.cuda.amp.autocast():
output, score = model(batch)
loss = calculate_loss(model, score, batch['label'])
scalar.scale(loss).backward()
# NOTE(review): no scalar.unscale_(optimizer) call is visible before this
# clipping, so the norm appears to be computed on scaled gradients - confirm.
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

# The condition is truthy when (index + 1) is odd, i.e. for index 0, 2, 4, ...
# NOTE(review): zero_grad() above appears to run for every batch, so gradients
# of batches that are not stepped would be discarded rather than accumulated -
# confirm the intent. It is also unclear here whether scalar.update() is
# inside this `if`.
if (index + 1) % 2:
scalar.step(optimizer)
# loss.backward()
# optimizer.step()
scalar.update()

# Number of samples in the batch, passed to the loss meter with the batch loss.
count = batch["id"].size(0)
loss_meter.update(loss.detach(), count)

# tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))

prediction = output.detach()
predictions.append(prediction)

target = batch['label'].detach()
targets.append(target)

return loss_meter, targets, predictions


def validation_epoch(config, model, validation_loader):
    """Evaluate ``model`` on every batch of ``validation_loader`` without gradients.

    Args:
        config: Configuration handed to ``batch_constructor``.
        model: Called as ``model(batch)``, expected to return ``(output, score)``.
        validation_loader: Iterable of batch dictionaries containing
            ``'label'`` and ``'id'``.

    Returns:
        Tuple ``(loss_meter, targets, predictions)``: the ``AvgMeter`` updated
        with every batch loss, and per-batch lists of detached labels and
        detached model outputs.
    """
    meter = AvgMeter('validation')
    collected_targets, collected_predictions = [], []

    for raw_batch in validation_loader:
        batch = batch_constructor(config, raw_batch)

        with torch.no_grad():
            output, score = model(batch)
            loss = calculate_loss(model, score, batch['label'])

        batch_size = batch["id"].size(0)
        meter.update(loss.detach(), batch_size)

        collected_predictions.append(output.detach())
        collected_targets.append(batch['label'].detach())

    return meter, collected_targets, collected_predictions


def _save_trial_checkpoint(path, config, model, optimizer, history):
    """Write model/optimizer state, the stringified config and the training curves to ``path``."""
    torch.save({'model_state_dict': model.state_dict(),
                'parameters': str(config),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss_accuracy': pd.DataFrame(history)}, path)


def supervised_train(config, train_loader, validation_loader, trial=None):
    """Train a ``FakeNewsModel`` and return its validation accuracy.

    Each epoch runs ``train_epoch`` and ``validation_epoch``, steps the
    plateau scheduler and the early-stopping tracker on the validation loss,
    and lets ``CheckpointSaving`` write ``checkpoint.pt`` under
    ``config.output_path``. After training, the saved weights are reloaded and
    re-evaluated; that accuracy is the return value.

    Args:
        config: run configuration (learning rates, patience, paths, ...).
        train_loader: loader passed to ``train_epoch``.
        validation_loader: loader passed to ``validation_epoch``.
        trial: optional Optuna trial. When given, the per-epoch validation
            accuracy is reported to it, pruning is honoured, and a
            ``checkpoint_<trial.number>.pt`` file is written when the final
            accuracy reaches ``config.wanted_accuracy``.

    Returns:
        The validation accuracy computed by ``multiclass_acc``.

    Raises:
        optuna.exceptions.TrialPruned: when the trial asks to be pruned.
    """
    torch.cuda.empty_cache()
    checkpoint_path2 = checkpoint_path = str(config.output_path) + '/checkpoint.pt'
    if trial:
        checkpoint_path2 = str(config.output_path) + '/checkpoint_' + str(trial.number) + '.pt'

    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)
    torch.autograd.set_detect_anomaly(False)
    # NOTE(review): the two calls below only construct profiler context
    # managers; they are never entered, so they look like no-ops - confirm
    # before removing.
    torch.autograd.profiler.profile(False)
    torch.autograd.profiler.emit_nvtx(False)

    scalar = torch.cuda.amp.GradScaler()
    model = FakeNewsModel(config).to(config.device)

    # One parameter group per sub-network so each gets its own learning rate.
    params = [
        {"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr, "name": 'image_encoder'},
        {"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr, "name": 'text_encoder'},
        {"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()),
         "lr": config.head_lr, "weight_decay": config.head_weight_decay, 'name': 'projection'},
        {"params": model.classifier.parameters(), "lr": config.classification_lr,
         "weight_decay": config.classification_weight_decay,
         'name': 'classifier'}
    ]
    optimizer = torch.optim.AdamW(params, amsgrad=True)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.factor,
                                                              patience=config.patience // 5, verbose=True)
    early_stopping = EarlyStopping(patience=config.patience, delta=config.delta, path=checkpoint_path, verbose=True)
    checkpoint_saving = CheckpointSaving(path=checkpoint_path, verbose=True)

    # Per-epoch curves; the key order defines the column order of the saved DataFrame.
    history = {'train_loss': [], 'train_accuracy': [], 'validation_loss': [], 'validation_accuracy': []}

    validation_accuracy, validation_loss = 0, 1
    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}")
        gc.collect()

        model.train()
        train_loss, train_truth, train_pred = train_epoch(config, model, train_loader, optimizer, scalar)
        model.eval()
        with torch.no_grad():
            validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)

        train_accuracy = multiclass_acc(train_truth, train_pred)
        validation_accuracy = multiclass_acc(validation_truth, validation_pred)
        print_lr(optimizer)
        print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy)
        print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy)

        if lr_scheduler:
            lr_scheduler.step(validation_loss.avg)
        if early_stopping:
            early_stopping(validation_loss.avg, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
        if checkpoint_saving:
            checkpoint_saving(validation_accuracy, model)

        history['train_accuracy'].append(train_accuracy)
        history['train_loss'].append(train_loss)
        history['validation_accuracy'].append(validation_accuracy)
        history['validation_loss'].append(validation_loss)

        if trial:
            trial.report(validation_accuracy, epoch)
            if trial.should_prune():
                print('trial pruned')
                raise optuna.exceptions.TrialPruned()

        print()

    if checkpoint_saving:
        # Re-evaluate the weights that CheckpointSaving stored, not the last epoch's.
        model = FakeNewsModel(config).to(config.device)
        model.load_state_dict(torch.load(checkpoint_path))
        model.eval()
        with torch.no_grad():
            validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)
        # Same (truth, prediction) argument order as inside the epoch loop.
        validation_accuracy = multiclass_acc(validation_truth, validation_pred)
    else:
        torch.save(model.state_dict(), checkpoint_path)

    if trial and validation_accuracy >= config.wanted_accuracy:
        _save_trial_checkpoint(checkpoint_path2, config, model, optimizer, history)

    # No explicit ``del`` of the locals: they are released when the function
    # returns. The previous ``del`` chain aborted on the first unbound name
    # and printed a misleading error.
    return validation_accuracy


+ 63
- 0
main.py View File

@@ -0,0 +1,63 @@
import argparse

from optuna_main import optuna_main
from test_main import test_main
from torch_main import torch_main
from data.twitter.config import TwitterConfig
from data.weibo.config import WeiboConfig
from torch.utils.tensorboard import SummaryWriter

if __name__ == '__main__':
    import os

    # Set before any tokenizer is used by the training/testing code.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--use_optuna', type=int, required=False)
    parser.add_argument('--just_test', type=int, required=False)
    parser.add_argument('--batch', type=int, required=False)
    parser.add_argument('--epoch', type=int, required=False)
    parser.add_argument('--extra', type=str, required=False)

    args = parser.parse_args()

    if args.data == 'twitter':
        config = TwitterConfig()
    elif args.data == 'weibo':
        config = WeiboConfig()
    else:
        raise Exception('Enter a valid dataset name', args.data)

    # Command-line values override the defaults stored in the config.
    if args.batch:
        config.batch_size = args.batch
    if args.epoch:
        config.epochs = args.epoch

    if args.use_optuna:
        config.output_path += 'logs/' + args.data + '_optuna' + '_' + str(args.extra)
    else:
        config.output_path += 'logs/' + args.data + '_' + str(args.extra)

    # NOTE(review): when --extra is omitted, str(args.extra) is the literal
    # 'None', so the output directory becomes "None" and Optuna is disabled.
    # Confirm this is the intended "scratch run" behaviour.
    use_optuna = True
    if not args.extra or 'temp' in args.extra:
        config.output_path = str(args.extra)
        use_optuna = False

    # makedirs also creates the missing "logs/" parent and does not fail when
    # the directory is left over from a previous run.
    try:
        os.makedirs(config.output_path, exist_ok=True)
    except OSError:
        print("Creation of the directory failed")
    else:
        print("Successfully created the directory")

    config.writer = SummaryWriter(config.output_path)

    # Close the writer even when training or testing raises.
    try:
        if args.use_optuna and use_optuna:
            optuna_main(config, args.use_optuna)
        elif args.just_test:
            test_main(config, args.just_test)
        else:
            torch_main(config)
    finally:
        config.writer.close()

+ 143
- 0
model.py View File

@@ -0,0 +1,143 @@
import torch
import torch.nn.functional as F
from torch import nn

from image import ImageTransformerEncoder, ImageResnetEncoder
from text import TextEncoder


class ProjectionHead(nn.Module):
    """Residual projection block that maps encoder features to ``config.projection_size``."""

    def __init__(self, config, embedding_dim):
        super().__init__()
        self.config = config
        width = config.projection_size
        # Attribute names and creation order are kept as-is: they define the
        # state-dict keys and the order in which the layers are initialised.
        self.projection = nn.Linear(embedding_dim, width)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(width)

    def forward(self, x):
        """Project ``x``, refine it, add the skip connection and layer-normalise."""
        shortcut = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(shortcut)))
        return self.layer_norm(refined + shortcut)


class Classifier(nn.Module):
    """Classification head applied to the concatenated image/text projections."""

    def __init__(self, config):
        super().__init__()
        fused_width = 2 * config.projection_size
        hidden_width = config.hidden_size
        # Attribute names and creation order are kept as-is: they define the
        # state-dict keys and the order in which the layers are initialised.
        self.layer_norm_1 = nn.LayerNorm(fused_width)
        self.linear_layer = nn.Linear(fused_width, hidden_width)
        self.gelu = nn.GELU()
        self.drop_out = nn.Dropout(config.dropout)
        self.layer_norm_2 = nn.LayerNorm(hidden_width)
        self.classifier_layer = nn.Linear(hidden_width, config.class_num)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return softmax class probabilities; the hidden features are kept in ``self.embeddings``."""
        hidden = self.linear_layer(self.layer_norm_1(x))
        hidden = self.drop_out(self.gelu(hidden))
        # Stored so callers can read the pre-classification representation.
        self.embeddings = self.layer_norm_2(hidden)
        return self.softmax(self.classifier_layer(self.embeddings))


class FakeNewsModel(nn.Module):
    """Multi-modal classifier: image encoder + text encoder -> projections -> classifier.

    After each forward pass the projected image/text embeddings and their
    concatenation are available as ``image_embeddings``, ``text_embeddings``
    and ``multimodal_embeddings``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # A ResNet encoder is used when the configured image model name
        # mentions it; otherwise the transformer image encoder is used.
        uses_resnet = 'resnet' in self.config.image_model_name
        image_encoder_cls = ImageResnetEncoder if uses_resnet else ImageTransformerEncoder
        self.image_encoder = image_encoder_cls(config)
        self.text_encoder = TextEncoder(config)

        self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
        self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
        self.classifier = Classifier(config)
        self.temperature = config.temperature
        self.classifier_loss_function = torch.nn.CrossEntropyLoss(
            weight=torch.FloatTensor(config.class_weights), reduction='mean')

        # Filled in by forward().
        self.text_embeddings = None
        self.image_embeddings = None
        self.multimodal_embeddings = None

    def forward(self, batch):
        """Return ``(output, score)``: the arg-max class index per row and the classifier scores."""
        sample_ids = batch['id']
        image_features = self.image_encoder(ids=sample_ids, image=batch["image"])
        text_features = self.text_encoder(
            ids=sample_ids, input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

        self.image_embeddings = self.image_projection(image_features)
        self.text_embeddings = self.text_projection(text_features)
        self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)

        score = self.classifier(self.multimodal_embeddings)
        _, output = torch.max(score.data, dim=1)
        return output, score


class FakeNewsEmbeddingModel(nn.Module):
    """Variant of the classifier that consumes embeddings supplied in the batch.

    The batch must provide ``'image_embedding'`` and ``'text_embedding'``
    entries; no encoder is run here.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
        self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
        self.classifier = Classifier(config)
        self.temperature = config.temperature
        self.classifier_loss_function = torch.nn.CrossEntropyLoss(
            weight=torch.FloatTensor(config.class_weights), reduction='mean')

    def forward(self, batch):
        """Return ``(output, score)``: the arg-max class index per row and the classifier scores."""
        self.image_embeddings = self.image_projection(batch['image_embedding'])
        self.text_embeddings = self.text_projection(batch['text_embedding'])
        self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)

        score = self.classifier(self.multimodal_embeddings)
        _, output = torch.max(score.data, dim=1)
        return output, score


def calculate_loss(model, score, label):
    """Combine the weighted classification loss with the image/text similarity loss.

    Args:
        model: model exposing ``config``, ``image_embeddings``,
            ``text_embeddings`` and ``classifier_loss_function``.
        score: class probabilities returned by the model; they have already
            been passed through the classifier's softmax.
        label: ground-truth class indices.

    Returns:
        ``config.loss_weight * classification_loss + mean(similarity_loss)``.
    """
    similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings)

    # ``score`` is already a probability distribution. Passing it to
    # CrossEntropyLoss would apply log-softmax a second time and flatten the
    # gradients, so take the log here and use the negative log-likelihood
    # with the same class weights and reduction.
    loss_function = model.classifier_loss_function
    log_probs = torch.log(score.clamp_min(torch.finfo(score.dtype).tiny))
    c_loss = F.nll_loss(log_probs, label,
                        weight=loss_function.weight,
                        ignore_index=loss_function.ignore_index,
                        reduction=loss_function.reduction)

    loss = model.config.loss_weight * c_loss + similarity.mean()
    return loss


def calculate_similarity_loss(config, image_embeddings, text_embeddings):
    """Contrastive image/text loss with soft targets built from intra-modal similarity.

    Assumes both embeddings are 2-D with one row per sample - TODO confirm.

    Returns:
        The per-row average of the text-to-image and image-to-text soft
        cross-entropies (no reduction over rows).
    """
    temperature = config.temperature

    # Soft targets: averaged image-image and text-text similarities.
    image_self = torch.matmul(image_embeddings, image_embeddings.T)
    text_self = torch.matmul(text_embeddings, text_embeddings.T)
    soft_targets = F.softmax((image_self + text_self) / 2 * temperature, dim=-1)

    # Cross-modal logits, scaled by the temperature.
    cross_logits = torch.matmul(text_embeddings, image_embeddings.T) / temperature

    per_text = cross_entropy(cross_logits, soft_targets, reduction='none')
    per_image = cross_entropy(cross_logits.T, soft_targets.T, reduction='none')
    return (per_image + per_text) / 2.0


def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits ``preds`` and soft target distributions ``targets``.

    Args:
        preds: logits; log-softmax is taken over the last dimension.
        targets: target weights with the same shape as ``preds``.
        reduction: ``'none'`` returns one loss per row, ``'mean'`` their mean.

    Raises:
        ValueError: for any other ``reduction`` (previously ``None`` was
            returned silently).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"Unsupported reduction: {reduction!r} (expected 'none' or 'mean')")

+ 62
- 0
optuna_main.py View File

@@ -0,0 +1,62 @@


import joblib
import optuna
from optuna.pruners import MedianPruner
from optuna.trial import TrialState

from data_loaders import make_dfs, build_loaders
from learner import supervised_train
from test_main import test


def objective(trial, config, train_loader, validation_loader):
    """Optuna objective: apply the trial to ``config``, train, and return the validation accuracy."""
    config.optuna(trial=trial)
    print('Trial', trial.number, 'parameters', trial.params)
    return supervised_train(config, train_loader, validation_loader, trial=trial)


def optuna_main(config, n_trials=100):
    """Run an Optuna study, persist its results, and test the best trial.

    The study is stored in ``optuna.db`` under ``config.output_path`` and is
    resumed if it already exists. The pickled study and a text summary of the
    best trial are written to the same directory.
    """
    train_df, test_df, validation_df, _ = make_dfs(config)
    loaders = {
        mode: build_loaders(config, frame, mode=mode)
        for mode, frame in (("train", train_df), ("validation", validation_df), ("test", test_df))
    }

    study = optuna.create_study(
        study_name=config.output_path.split('/')[-1],
        sampler=optuna.samplers.TPESampler(),
        storage=f'sqlite:///{config.output_path + "/optuna.db"}',
        load_if_exists=True,
        direction="maximize",
        pruner=MedianPruner(n_startup_trials=10, n_warmup_steps=10),
    )

    def run_trial(trial):
        return objective(trial, config, loaders["train"], loaders["validation"])

    study.optimize(run_trial, n_trials=n_trials)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    joblib.dump(study, str(config.output_path) + '/study_optuna_model' + '.pkl')

    print("Study statistics: ")
    print(" Number of finished trials: ", len(study.trials))
    print(" Number of pruned trials: ", len(pruned_trials))
    print(" Number of complete trials: ", len(complete_trials))

    print("Best trial:")
    best = study.best_trial
    print(' Number: ', best.number)
    print(" Value: ", best.value)

    # Pieces of the text summary, joined once when the file is written.
    summary = ['value: ' + str(best.value) + '\n']

    print(" Params: ")
    summary.append('params: \n')
    for key, value in best.params.items():
        print(" {}: {}".format(key, value))
        summary.append(" {}: {}\n".format(key, value))

    with open(config.output_path+'/optuna_results.txt', 'w') as f:
        f.write(''.join(summary))

    test(config, loaders["test"], trial_number=best.number)


+ 17
- 0
requirements.txt View File

@@ -0,0 +1,17 @@
opencv-python==4.5.1.48
optuna==2.8.0
pandas==1.1.5
Pillow==8.2.0
pytorch-lightning==1.2.10
pytorch-model-summary==0.1.2
pytorchtools==0.0.2
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.5.4
seaborn==0.11.1
torch==1.8.1
torchmetrics==0.2.0
torchsummary==1.5.1
torchvision==0.9.1
tqdm==4.60.0
transformers==4.5.1

+ 111
- 0
test_main.py View File

@@ -0,0 +1,111 @@
import random

import numpy as np
import torch
from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca
from learner import batch_constructor
from model import FakeNewsModel


def _load_test_checkpoint(config, trial_number):
    """Load the checkpoint to evaluate, preferring the per-trial file when a trial is given."""
    device = torch.device(config.device)
    if trial_number is not None:
        trial_path = str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt'
        try:
            return torch.load(trial_path, map_location=device)
        except FileNotFoundError:
            # No per-trial file exists; fall back to the shared checkpoint.
            pass
    return torch.load(str(config.output_path) + '/checkpoint.pt', map_location=device)


def test(config, test_loader, trial_number=None):
    """Evaluate a saved ``FakeNewsModel`` on ``test_loader`` and write reports and plots.

    Args:
        config: run configuration; ``output_path`` and ``device`` are read
            here, and hyper-parameters stored in the checkpoint are applied
            to it when present.
        test_loader: loader yielding the evaluation batches.
        trial_number: optional Optuna trial number. ``0`` is a valid trial;
            only ``None`` means "use the shared checkpoint".

    Side effects:
        Writes ``results.txt``, ``fpr_tpr.csv``, ``parameters.txt`` and the
        ROC / precision-recall / t-SNE / PCA figures under ``config.output_path``.
    """
    checkpoint = _load_test_checkpoint(config, trial_number)

    # A checkpoint is either a rich dict ('parameters', 'model_state_dict',
    # ...) or a bare state dict; only the rich form carries hyper-parameters.
    if 'parameters' in checkpoint:
        try:
            config.assign_hyperparameters(checkpoint['parameters'])
        except Exception as error:
            # Best effort, as before, but no longer silent.
            print('Could not restore hyperparameters from checkpoint:', error)

    model = FakeNewsModel(config).to(config.device)
    model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))

    model.eval()

    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    image_features = []
    text_features = []
    multimodal_features = []
    concat_features = []

    targets = []
    predictions = []
    scores = []
    tqdm_object = tqdm(test_loader, total=len(test_loader))
    for batch in tqdm_object:
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            output, score = model(batch)

        predictions.append(output.detach())
        scores.append(score.detach())
        targets.append(batch['label'].detach())

        # Intermediate representations left on the model by the forward pass.
        image_features.append(model.image_embeddings.detach())
        text_features.append(model.text_embeddings.detach())
        multimodal_features.append(model.multimodal_embeddings.detach())
        concat_features.append(model.classifier.embeddings.detach())

    s = ''
    s += report_per_class(targets, predictions) + '\n'
    s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n'
    with open(config.output_path + '/results.txt', 'w') as f:
        f.write(s)

    roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
    precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")

    plot_tsne(config, image_features, targets, fname=str(config.output_path) + '/image_features_tsne.png')
    plot_pca(config, image_features, targets, fname=str(config.output_path) + '/image_features_pca.png')

    plot_tsne(config, text_features, targets, fname=str(config.output_path) + '/text_features_tsne.png')
    plot_pca(config, text_features, targets, fname=str(config.output_path) + '/text_features_pca.png')

    plot_tsne(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_tsne.png')
    plot_pca(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_pca.png')

    plot_tsne(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_tsne.png')
    plot_pca(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_pca.png')

    config_parameters = str(config)
    with open(config.output_path + '/parameters.txt', 'w') as f:
        f.write(config_parameters)
    print(config)


def test_main(config, trial_number=None):
    """Build the test loader from the configured dataset and evaluate the saved model on it."""
    _, test_df, _, _ = make_dfs(config)
    loader = build_loaders(config, test_df, mode="test")
    test(config, loader, trial_number)

+ 47
- 0
text.py View File

@@ -0,0 +1,47 @@
from torch import nn
from transformers import BertModel, BertConfig, BertTokenizer, \
BigBirdModel, BigBirdConfig, BigBirdTokenizer, \
XLNetModel, XLNetConfig, XLNetTokenizer


class TextEncoder(nn.Module):
    """Wrap a transformer text model and return one hidden vector per input sequence."""

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.model = get_text_model(config)
        # Freeze or unfreeze the whole backbone according to the config.
        trainable = self.config.trainable
        for parameter in self.model.parameters():
            parameter.requires_grad = trainable
        # Position of the token whose hidden state represents the sequence.
        self.target_token_idx = 0
        self.text_encoder_embedding = dict()

    def forward(self, ids, input_ids, attention_mask):
        """Return the last-layer hidden state at ``target_token_idx``; ``ids`` is not used here."""
        hidden_states = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]


def get_text_model(config):
    """Build the text backbone named by ``config.text_encoder_model``.

    The family is chosen by substring: ``'bigbird'``, then ``'xlnet'``,
    otherwise BERT. With ``config.pretrained`` the weights are loaded via
    ``from_pretrained``; otherwise the model is created from the family's
    default configuration.
    """
    name = config.text_encoder_model
    hidden_state_kwargs = {'output_attentions': False, 'output_hidden_states': True, 'return_dict': True}

    if 'bigbird' in name:
        model_cls, config_cls = BigBirdModel, BigBirdConfig
        pretrained_kwargs = {'block_size': 16, 'num_random_blocks': 2}
    elif 'xlnet' in name:
        model_cls, config_cls = XLNetModel, XLNetConfig
        pretrained_kwargs = hidden_state_kwargs
    else:
        model_cls, config_cls = BertModel, BertConfig
        pretrained_kwargs = hidden_state_kwargs

    if config.pretrained:
        return model_cls.from_pretrained(name, **pretrained_kwargs)
    return model_cls(config=config_cls())


+ 19
- 0
torch_main.py View File

@@ -0,0 +1,19 @@

from data_loaders import build_loaders, make_dfs
from learner import supervised_train
from test_main import test


def torch_main(config):
    """Train once with the current configuration, then evaluate on the test split."""
    train_df, test_df, validation_df, _ = make_dfs(config)
    splits = (("train", train_df), ("validation", validation_df), ("test", test_df))
    train_loader, validation_loader, test_loader = (
        build_loaders(config, frame, mode=mode) for mode, frame in splits
    )

    supervised_train(config, train_loader, validation_loader)
    test(config, test_loader)






+ 93
- 0
utils.py View File

@@ -0,0 +1,93 @@
import numpy as np
import torch


class AvgMeter:
    """Running weighted average of a metric."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, count and average."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Add ``val`` observed ``count`` times and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"


def print_lr(optimizer):
    """Print the name and current learning rate of every parameter group."""
    named_rates = ((group['name'], group['lr']) for group in optimizer.param_groups)
    for group_name, rate in named_rates:
        print(group_name, rate)


class CheckpointSaving:
    """Save the model's state dict whenever the validation accuracy sets a new best."""

    def __init__(self, path='checkpoint.pt', verbose=True, trace_func=print):
        self.path = path
        self.verbose = verbose
        self.trace_func = trace_func
        self.best_score = None
        self.val_acc_max = 0

    def __call__(self, val_acc, model):
        """Save ``model`` on the first call and on every strict improvement afterwards."""
        is_first_call = self.best_score is None
        if is_first_call or val_acc > self.best_score:
            self.best_score = val_acc
            self.save_checkpoint(val_acc, model)

    def save_checkpoint(self, val_acc, model):
        """Write the state dict to ``self.path`` and remember ``val_acc`` as the saved accuracy."""
        if self.verbose:
            message = f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}). Model saved ...'
            self.trace_func(message)
        torch.save(model.state_dict(), self.path)
        self.val_acc_max = val_acc


class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Args:
        patience: number of consecutive non-improving calls after which
            ``early_stop`` becomes ``True``.
        verbose: when true, report every improvement through ``trace_func``.
        delta: minimum decrease of the loss that counts as an improvement.
        path: stored on the instance; nothing is written by this class.
        trace_func: callable used for messages (``print`` by default).
    """

    def __init__(self, patience=10, verbose=False, delta=0.000001, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.inf`` instead of the ``np.Inf`` alias, which NumPy 2.0 removed.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record ``val_loss``; ``model`` is accepted for interface compatibility and not used."""
        # Work with the negated loss so that "higher is better".
        score = -val_loss

        if self.best_score is not None and score < self.best_score + self.delta:
            # Not better than the best score by at least ``delta``.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # First call or a real improvement: remember it and restart the count.
        self.best_score = score
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
        self.val_loss_min = val_loss
        self.counter = 0

Loading…
Cancel
Save