| @@ -114,3 +114,6 @@ dmypy.json | |||
| # Pyre type checker | |||
| .pyre/ | |||
| .idea/ | |||
| env/ | |||
| venv/ | |||
| @@ -1,3 +1,48 @@ | |||
| # FakeNewsRevealer | |||
| Official implementation of the "Fake News Revealer (FNR): A Similarity and Transformer-Based Approach to Detect | |||
| Multi-Modal Fake News in Social Media" paper [(ArXiv Link)](https://arxiv.org/pdf/2112.01131.pdf). | |||
## Requirements
| FNR is built in Python 3.6 using PyTorch 1.8. Please use the following command to install the requirements: | |||
| ``` | |||
| pip install -r requirements.txt | |||
| ``` | |||
| ## How to Run | |||
First, set the data paths and configuration in the config file in the data directory, then follow the train and test commands below.
| ### train | |||
To run with Optuna for hyperparameter tuning, use this command:
| ``` | |||
python main.py --data "DATA NAME" --use_optuna "NUMBER OF OPTUNA TRIALS" --batch "BATCH SIZE" --epoch "NUMBER OF EPOCHS"
| ``` | |||
To run without parameter tuning, set your parameters in the config file and then use the command below:
| ``` | |||
python main.py --data "DATA NAME" --batch "BATCH SIZE" --epoch "NUMBER OF EPOCHS"
| ``` | |||
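For example, a plain training run (no tuning) on the Twitter data might look like the command below; the batch size, epoch count, and `--extra` tag (used to name the output directory under `logs/`) are only illustrative:
```
python main.py --data twitter --batch 128 --epoch 100 --extra run1
```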
| ### test | |||
For testing, first make sure the requested checkpoint file exists, then run the following command:
| ``` | |||
python main.py --data "DATA NAME" --just_test "REQUESTED TRIAL NUMBER"
| ``` | |||
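For instance, to evaluate the checkpoint saved for trial 3 under the output directory `logs/twitter_run1` (the trial number and `--extra` tag are illustrative; if no per-trial checkpoint exists, the script falls back to `checkpoint.pt`):
```
python main.py --data twitter --just_test 3 --extra run1
```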
## BibTeX
Cite our paper using the following BibTeX entry:
| ``` | |||
| @misc{ghorbanpour2021fnr, | |||
| title={FNR: A Similarity and Transformer-Based Approach to Detect Multi-Modal Fake News in Social Media}, | |||
| author={Faeze Ghorbanpour and Maryam Ramezani and Mohammad A. Fazli and Hamid R. Rabiee}, | |||
| year={2021}, | |||
| eprint={2112.01131}, | |||
| archivePrefix={arXiv}, | |||
| } | |||
| ``` | |||
| @@ -0,0 +1,117 @@ | |||
| import torch | |||
| import ast | |||
| class Config: | |||
| name = '' | |||
| DatasetLoader = None | |||
| debug = False | |||
| data_path = '' | |||
| output_path = '' | |||
| train_image_path = '' | |||
| validation_image_path = '' | |||
| test_image_path = '' | |||
| train_text_path = '' | |||
| validation_text_path = '' | |||
| test_text_path = '' | |||
| train_text_embedding_path = '' | |||
| validation_text_embedding_path = '' | |||
| train_image_embedding_path = '' | |||
| validation_image_embedding_path = '' | |||
| batch_size = 256 | |||
| epochs = 50 | |||
| num_workers = 2 | |||
| head_lr = 1e-03 | |||
| image_encoder_lr = 1e-04 | |||
| text_encoder_lr = 1e-04 | |||
| attention_lr = 1e-3 | |||
| classification_lr = 1e-03 | |||
| max_grad_norm = 5.0 | |||
| head_weight_decay = 0.001 | |||
| attention_weight_decay = 0.001 | |||
| classification_weight_decay = 0.001 | |||
| patience = 30 | |||
| delta = 0.0000001 | |||
| factor = 0.8 | |||
| image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||
| image_embedding = 768 | |||
| num_img_region = 64 # 16 #TODO | |||
| text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
| text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
| text_embedding = 768 | |||
| max_length = 32 | |||
| pretrained = True # for both image encoder and text encoder | |||
| trainable = False # for both image encoder and text encoder | |||
| temperature = 1.0 | |||
| # image size | |||
| size = 224 | |||
| num_projection_layers = 1 | |||
| projection_size = 256 | |||
| dropout = 0.3 | |||
| hidden_size = 256 | |||
| num_region = 64 # 16 #TODO | |||
| region_size = projection_size // num_region | |||
| classes = ['real', 'fake'] | |||
| class_num = 2 | |||
| loss_weight = 1 | |||
| class_weights = [1, 1] | |||
| writer = None | |||
| has_unlabeled_data = False | |||
| step = 0 | |||
| T1 = 10 | |||
| T2 = 150 | |||
| af = 3 | |||
| wanted_accuracy = 0.76 | |||
| def optuna(self, trial): | |||
| self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||
| self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||
| self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||
| self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||
| self.attention_lr = trial.suggest_loguniform('attention_lr', 1e-5, 1e-1) | |||
| self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||
| self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||
| self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||
| self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||
| self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
| self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
| def __str__(self): | |||
| s = '' | |||
| members = [attr for attr in dir(self) if not callable(getattr(self, attr)) and not attr.startswith("__")] | |||
| for member in members: | |||
| s += member + '\t' + str(getattr(self, member)) + '\n' | |||
| return s | |||
| def assign_hyperparameters(self, s): | |||
| for line in s.split('\n'): | |||
| s = line.split('\t') | |||
| try: | |||
| attr = getattr(self, s[0]) | |||
| if type(attr) not in [list, set, dict]: | |||
| setattr(self, s[0], type(attr)(s[1])) | |||
| else: | |||
| setattr(self, s[0], ast.literal_eval(s[1])) | |||
| except: | |||
print(s[0], 'does not exist')
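# Illustrative sketch (not used by the pipeline): __str__ dumps the
# hyperparameters as "name<TAB>value" lines, and assign_hyperparameters can
# restore them from such a dump (e.g. the 'parameters' string stored in a
# checkpoint). Attributes that cannot be parsed back are reported and skipped.
if __name__ == '__main__':
    config = Config()
    dump = str(config)              # one "name\tvalue" pair per line
    config.batch_size = 1           # pretend a value was changed...
    config.assign_hyperparameters(dump)
    print(config.batch_size)        # ...and restored from the dump: 256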
| @@ -0,0 +1,33 @@ | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| class DatasetLoader(Dataset): | |||
| def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
| self.config = config | |||
| self.image_filenames = image_filenames | |||
| self.text = list(texts) | |||
| self.encoded_text = tokenizer( | |||
| list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt' | |||
| ) | |||
| self.labels = labels | |||
| self.transforms = transforms | |||
| self.mode = mode | |||
| def set_text(self, idx): | |||
| item = { | |||
| key: values[idx].clone().detach() | |||
| for key, values in self.encoded_text.items() | |||
| } | |||
| item['text'] = self.text[idx] | |||
| item['label'] = self.labels[idx] | |||
| item['id'] = idx | |||
| return item | |||
| def set_image(self, image, item): | |||
| if 'resnet' in self.config.image_model_name: | |||
| image = self.transforms(image=image)['image'] | |||
| item['image'] = torch.as_tensor(image).reshape((3, 224, 224)) | |||
| else: | |||
| image = self.transforms(images=image, return_tensors='pt') | |||
| image = image.convert_to_tensors(tensor_type='pt')['pixel_values'] | |||
| item['image'] = image.reshape((3, 224, 224)) | |||
| @@ -0,0 +1,65 @@ | |||
| import torch | |||
| from data.config import Config | |||
| from data.twitter.data_loader import TwitterDatasetLoader | |||
| class TwitterConfig(Config): | |||
| name = 'twitter' | |||
| DatasetLoader = TwitterDatasetLoader | |||
| data_path = '/twitter/' | |||
| output_path = '' | |||
| train_image_path = data_path + 'images_train/' | |||
| validation_image_path = data_path + 'images_validation/' | |||
| test_image_path = data_path + 'images_test/' | |||
| train_text_path = data_path + 'twitter_train_translated.csv' | |||
| validation_text_path = data_path + 'twitter_validation_translated.csv' | |||
| test_text_path = data_path + 'twitter_test_translated.csv' | |||
| batch_size = 128 | |||
| epochs = 100 | |||
| num_workers = 2 | |||
| head_lr = 1e-03 | |||
| image_encoder_lr = 1e-04 | |||
| text_encoder_lr = 1e-04 | |||
| attention_lr = 1e-3 | |||
| classification_lr = 1e-03 | |||
| head_weight_decay = 0.001 | |||
| attention_weight_decay = 0.001 | |||
| classification_weight_decay = 0.001 | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| image_model_name = 'vit-base-patch16-224' | |||
| image_embedding = 768 | |||
| text_encoder_model = "bert-base-uncased" | |||
| text_tokenizer = "bert-base-uncased" | |||
| text_embedding = 768 | |||
| max_length = 32 | |||
| pretrained = True | |||
| trainable = False | |||
| temperature = 1.0 | |||
| classes = ['real', 'fake'] | |||
| class_weights = [1, 1] | |||
| wanted_accuracy = 0.76 | |||
| def optuna(self, trial): | |||
| self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||
| self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||
| self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||
| self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||
| self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||
| # self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||
| self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||
| # self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||
| # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
| # self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
| @@ -0,0 +1,65 @@ | |||
| import cv2 | |||
| import pickle | |||
| from PIL import Image | |||
| from numpy import asarray | |||
| from data.data_loader import DatasetLoader | |||
| class TwitterDatasetLoader(DatasetLoader): | |||
| def __getitem__(self, idx): | |||
| item = self.set_text(idx) | |||
if self.mode == 'train':
    image_path = f"{self.config.train_image_path}/{self.image_filenames[idx]}"
elif self.mode == 'validation':
    image_path = f"{self.config.validation_image_path}/{self.image_filenames[idx]}"
else:
    image_path = f"{self.config.test_image_path}/{self.image_filenames[idx]}"
try:
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
except:
    # Fall back to PIL when OpenCV cannot decode the image.
    image = Image.open(image_path).convert('RGB')
    image = asarray(image)
| self.set_image(image, item) | |||
| return item | |||
| def __len__(self): | |||
| return len(self.text) | |||
| class TwitterEmbeddingDatasetLoader(DatasetLoader): | |||
| def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
| super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||
| self.config = config | |||
| self.image_filenames = image_filenames | |||
| self.text = list(texts) | |||
| self.encoded_text = tokenizer | |||
| self.labels = labels | |||
| self.transforms = transforms | |||
| self.mode = mode | |||
| if mode == 'train': | |||
| with open(config.train_image_embedding_path, 'rb') as handle: | |||
| self.image_embedding = pickle.load(handle) | |||
| with open(config.train_text_embedding_path, 'rb') as handle: | |||
| self.text_embedding = pickle.load(handle) | |||
| else: | |||
| with open(config.validation_image_embedding_path, 'rb') as handle: | |||
| self.image_embedding = pickle.load(handle) | |||
| with open(config.validation_text_embedding_path, 'rb') as handle: | |||
| self.text_embedding = pickle.load(handle) | |||
| def __getitem__(self, idx): | |||
| item = dict() | |||
| item['id'] = idx | |||
| item['image_embedding'] = self.image_embedding[idx] | |||
| item['text_embedding'] = self.text_embedding[idx] | |||
| item['label'] = self.labels[idx] | |||
| return item | |||
| def __len__(self): | |||
| return len(self.text) | |||
| @@ -0,0 +1,62 @@ | |||
| import torch | |||
| from transformers import BertTokenizer, BertModel, BertConfig | |||
| from data.config import Config | |||
| from data.weibo.data_loader import WeiboDatasetLoader | |||
| class WeiboConfig(Config): | |||
| name = 'weibo' | |||
| DatasetLoader = WeiboDatasetLoader | |||
| data_path = 'weibo/' | |||
| output_path = '' | |||
| rumor_image_path = data_path + 'rumor_images/' | |||
| nonrumor_image_path = data_path + 'nonrumor_images/' | |||
| train_text_path = data_path + 'weibo_train.csv' | |||
| validation_text_path = data_path + 'weibo_validation.csv' | |||
| test_text_path = data_path + 'weibo_test.csv' | |||
| batch_size = 128 | |||
| epochs = 100 | |||
| num_workers = 2 | |||
| head_lr = 1e-03 | |||
| image_encoder_lr = 1e-02 | |||
| text_encoder_lr = 1e-05 | |||
| weight_decay = 0.001 | |||
| classification_lr = 1e-02 | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| image_model_name = 'vit-base-patch16-224' # 'resnet101' | |||
| image_embedding = 768 # 2048 | |||
| num_img_region = 64 # TODO | |||
| text_encoder_model = "bert-base-chinese" | |||
| text_tokenizer = "bert-base-chinese" | |||
| text_embedding = 768 | |||
| max_length = 200 | |||
| pretrained = True | |||
| trainable = False | |||
| temperature = 1.0 | |||
| labels = ['real', 'fake'] | |||
| wanted_accuracy = 0.80 | |||
| def optuna(self, trial): | |||
| self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1) | |||
| self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3) | |||
| self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3) | |||
| self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||
| self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||
| # self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||
| self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||
| self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||
| # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
| # self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
| @@ -0,0 +1,66 @@ | |||
| import pickle | |||
| import cv2 | |||
| from PIL import Image | |||
| from numpy import asarray | |||
| from data.data_loader import DatasetLoader | |||
| class WeiboDatasetLoader(DatasetLoader): | |||
| def __getitem__(self, idx): | |||
| item = self.set_text(idx) | |||
| if self.labels[idx] == 1: | |||
| try: | |||
| image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}") | |||
| except: | |||
| image = Image.open(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
| image = asarray(image) | |||
| else: | |||
| try: | |||
| image = cv2.imread(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}") | |||
| except: | |||
| image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
| image = asarray(image) | |||
| self.set_image(image, item) | |||
| return item | |||
| def __len__(self): | |||
| return len(self.text) | |||
| class WeiboEmbeddingDatasetLoader(DatasetLoader): | |||
| def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
| super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||
| self.config = config | |||
| self.image_filenames = image_filenames | |||
| self.text = list(texts) | |||
| self.encoded_text = tokenizer | |||
| self.labels = labels | |||
| self.transforms = transforms | |||
| self.mode = mode | |||
| if mode == 'train': | |||
| with open(config.train_image_embedding_path, 'rb') as handle: | |||
| self.image_embedding = pickle.load(handle) | |||
| with open(config.train_text_embedding_path, 'rb') as handle: | |||
| self.text_embedding = pickle.load(handle) | |||
| else: | |||
| with open(config.validation_image_embedding_path, 'rb') as handle: | |||
| self.image_embedding = pickle.load(handle) | |||
| with open(config.validation_text_embedding_path, 'rb') as handle: | |||
| self.text_embedding = pickle.load(handle) | |||
| def __getitem__(self, idx): | |||
| item = dict() | |||
| item['id'] = idx | |||
| item['image_embedding'] = self.image_embedding[idx] | |||
| item['text_embedding'] = self.text_embedding[idx] | |||
| item['label'] = self.labels[idx] | |||
| return item | |||
| def __len__(self): | |||
| return len(self.text) | |||
| @@ -0,0 +1,108 @@ | |||
| import albumentations as A | |||
| import numpy as np | |||
| import pandas as pd | |||
| from sklearn.utils import compute_class_weight | |||
| from torch.utils.data import DataLoader | |||
| from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||
| def get_transforms(config, mode="train"): | |||
| if mode == "train": | |||
| return A.Compose( | |||
| [ | |||
| A.Resize(config.size, config.size, always_apply=True), | |||
| A.Normalize(max_pixel_value=255.0, always_apply=True), | |||
| ] | |||
| ) | |||
| else: | |||
| return A.Compose( | |||
| [ | |||
| A.Resize(config.size, config.size, always_apply=True), | |||
| A.Normalize(max_pixel_value=255.0, always_apply=True), | |||
| ] | |||
| ) | |||
| def make_dfs(config): | |||
| train_dataframe = pd.read_csv(config.train_text_path) | |||
| train_dataframe.dropna(subset=['text'], inplace=True) | |||
| train_dataframe = train_dataframe.sample(frac=1).reset_index(drop=True) | |||
| train_dataframe.label = train_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
| config.class_weights = get_class_weights(train_dataframe.label.values) | |||
| if config.test_text_path is None: | |||
| offset = int(train_dataframe.shape[0] * 0.80) | |||
| test_dataframe = train_dataframe[offset:] | |||
| train_dataframe = train_dataframe[:offset] | |||
| else: | |||
| test_dataframe = pd.read_csv(config.test_text_path) | |||
| test_dataframe.dropna(subset=['text'], inplace=True) | |||
| test_dataframe = test_dataframe.sample(frac=1).reset_index(drop=True) | |||
| test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
| if config.validation_text_path is None: | |||
| offset = int(train_dataframe.shape[0] * 0.90) | |||
| validation_dataframe = train_dataframe[offset:] | |||
| train_dataframe = train_dataframe[:offset] | |||
| else: | |||
| validation_dataframe = pd.read_csv(config.validation_text_path) | |||
| validation_dataframe.dropna(subset=['text'], inplace=True) | |||
| validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True) | |||
| validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
| unlabeled_dataframe = None | |||
| if config.has_unlabeled_data: | |||
| unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path) | |||
| unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True) | |||
| return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe | |||
| def get_tokenizer(config): | |||
| if 'bigbird' in config.text_encoder_model: | |||
| tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||
| elif 'xlnet' in config.text_encoder_model: | |||
| tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||
| else: | |||
| tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer) | |||
| return tokenizer | |||
| def build_loaders(config, dataframe, mode): | |||
| if 'resnet' in config.image_model_name: | |||
| transforms = get_transforms(config, mode=mode) | |||
| else: | |||
| transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name) | |||
| dataset = config.DatasetLoader( | |||
| config, | |||
| dataframe["image"].values, | |||
| dataframe["text"].values, | |||
| dataframe["label"].values, | |||
| tokenizer=get_tokenizer(config), | |||
| transforms=transforms, | |||
| mode=mode, | |||
| ) | |||
| if mode != 'train': | |||
| dataloader = DataLoader( | |||
| dataset, | |||
| batch_size=config.batch_size, | |||
| num_workers=config.num_workers, | |||
| pin_memory=False, | |||
| shuffle=False, | |||
| ) | |||
| else: | |||
| dataloader = DataLoader( | |||
| dataset, | |||
| batch_size=config.batch_size, | |||
| num_workers=config.num_workers, | |||
| pin_memory=True, | |||
| shuffle=True, | |||
| ) | |||
| return dataloader | |||
| def get_class_weights(y): | |||
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)
| return class_weights | |||
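# Example (illustrative): for labels y = [0, 0, 0, 1], the 'balanced' heuristic
# returns n_samples / (n_classes * bincount(y)) = [4 / (2 * 3), 4 / (2 * 1)]
# ≈ [0.67, 2.0], so the rarer class gets the larger weight in the
# cross-entropy loss built from config.class_weights.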
| @@ -0,0 +1,257 @@ | |||
| from itertools import cycle | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| import pandas as pd | |||
| import seaborn as sns | |||
| from numpy import interp | |||
| from sklearn.decomposition import PCA | |||
| from sklearn.manifold import TSNE | |||
| from sklearn.metrics import accuracy_score, f1_score, precision_recall_curve, average_precision_score | |||
| from sklearn.metrics import classification_report, roc_curve, auc | |||
| from sklearn.preprocessing import OneHotEncoder | |||
| def metrics(truth, pred, prob, file_path): | |||
| truth = [i.cpu().numpy() for i in truth] | |||
| pred = [i.cpu().numpy() for i in pred] | |||
| prob = [i.cpu().numpy() for i in prob] | |||
| pred = np.concatenate(pred, axis=0) | |||
| truth = np.concatenate(truth, axis=0) | |||
| prob = np.concatenate(prob, axis=0) | |||
| prob = prob[:, 1] | |||
| f_score_micro = f1_score(truth, pred, average='micro', zero_division=0) | |||
| f_score_macro = f1_score(truth, pred, average='macro', zero_division=0) | |||
| f_score_weighted = f1_score(truth, pred, average='weighted', zero_division=0) | |||
accuracy = accuracy_score(truth, pred)
s = ''
print('accuracy', accuracy)
s += 'accuracy' + str(accuracy) + '\n'
| print('f_score_micro', f_score_micro) | |||
| s += 'f_score_micro' + str(f_score_micro) + '\n' | |||
| print('f_score_macro', f_score_macro) | |||
| s += 'f_score_macro' + str(f_score_macro) + '\n' | |||
| print('f_score_weighted', f_score_weighted) | |||
| s += 'f_score_weighted' + str(f_score_weighted) + '\n' | |||
| fpr, tpr, thresholds = roc_curve(truth, prob) | |||
| AUC = auc(fpr, tpr) | |||
| print('AUC', AUC) | |||
| s += 'AUC' + str(AUC) + '\n' | |||
| df = pd.DataFrame(dict(fpr=fpr, tpr=tpr)) | |||
| df.to_csv(file_path) | |||
| return s | |||
| def report_per_class(truth, pred): | |||
| truth = [i.cpu().numpy() for i in truth] | |||
| pred = [i.cpu().numpy() for i in pred] | |||
| pred = np.concatenate(pred, axis=0) | |||
| truth = np.concatenate(truth, axis=0) | |||
| report = classification_report(truth, pred, zero_division=0, output_dict=True) | |||
| s = '' | |||
| class_labels = [k for k in report.keys() if k not in ['micro avg', 'macro avg', 'weighted avg', 'samples avg']] | |||
| for class_label in class_labels: | |||
| print('class_label', class_label) | |||
| s += 'class_label' + str(class_label) + '\n' | |||
| s += str(report[class_label]) | |||
| print(report[class_label]) | |||
| return s | |||
| def multiclass_acc(truth, pred): | |||
| truth = [i.cpu().numpy() for i in truth] | |||
| pred = [i.cpu().numpy() for i in pred] | |||
| pred = np.concatenate(pred, axis=0) | |||
| truth = np.concatenate(truth, axis=0) | |||
| return accuracy_score(truth, pred) | |||
| def roc_auc_plot(truth, score, num_class=2, fname='roc.png'): | |||
| truth = [i.cpu().numpy() for i in truth] | |||
| score = [i.cpu().numpy() for i in score] | |||
| truth = np.concatenate(truth, axis=0) | |||
| score = np.concatenate(score, axis=0) | |||
| enc = OneHotEncoder(handle_unknown='ignore') | |||
| enc.fit(truth.reshape(-1, 1)) | |||
| label_onehot = enc.transform(truth.reshape(-1, 1)).toarray() | |||
| fpr_dict = dict() | |||
| tpr_dict = dict() | |||
| roc_auc_dict = dict() | |||
| for i in range(num_class): | |||
| fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score[:, i]) | |||
| roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i]) | |||
| # micro | |||
| fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score.ravel()) | |||
| roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"]) | |||
| # macro | |||
| all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)])) | |||
| mean_tpr = np.zeros_like(all_fpr) | |||
| for i in range(num_class): | |||
| mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i]) | |||
| mean_tpr /= num_class | |||
| fpr_dict["macro"] = all_fpr | |||
| tpr_dict["macro"] = mean_tpr | |||
| roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"]) | |||
| plt.figure() | |||
| lw = 2 | |||
| plt.plot(fpr_dict["micro"], tpr_dict["micro"], | |||
| label='micro-average ROC curve (area = {0:0.2f})' | |||
| ''.format(roc_auc_dict["micro"]), | |||
| color='deeppink', linestyle=':', linewidth=4) | |||
| plt.plot(fpr_dict["macro"], tpr_dict["macro"], | |||
| label='macro-average ROC curve (area = {0:0.2f})' | |||
| ''.format(roc_auc_dict["macro"]), | |||
| color='navy', linestyle=':', linewidth=4) | |||
| colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) | |||
| for i, color in zip(range(num_class), colors): | |||
| plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw, | |||
| label='ROC curve of class {0} (area = {1:0.2f})' | |||
| ''.format(i, roc_auc_dict[i])) | |||
| plt.plot([0, 1], [0, 1], 'k--', lw=lw) | |||
| plt.xlim([0.0, 1.0]) | |||
| plt.ylim([0.0, 1.05]) | |||
| plt.xlabel('False Positive Rate') | |||
| plt.ylabel('True Positive Rate') | |||
| plt.legend(loc="lower right") | |||
| plt.savefig(fname) | |||
| # plt.show() | |||
| def precision_recall_plot(truth, score, num_class=2, fname='pr.png'): | |||
| truth = [i.cpu().numpy() for i in truth] | |||
| score = [i.cpu().numpy() for i in score] | |||
| truth = np.concatenate(truth, axis=0) | |||
| score = np.concatenate(score, axis=0) | |||
| enc = OneHotEncoder(handle_unknown='ignore') | |||
| enc.fit(truth.reshape(-1, 1)) | |||
| label_onehot = enc.transform(truth.reshape(-1, 1)).toarray() | |||
| # Call the Sklearn library, calculate the precision and recall corresponding to each category | |||
| precision_dict = dict() | |||
| recall_dict = dict() | |||
| average_precision_dict = dict() | |||
| for i in range(num_class): | |||
| precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score[:, i]) | |||
| average_precision_dict[i] = average_precision_score(label_onehot[:, i], score[:, i]) | |||
| print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i]) | |||
| # micro | |||
| precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(), | |||
| score.ravel()) | |||
| average_precision_dict["micro"] = average_precision_score(label_onehot, score, average="micro") | |||
| # macro | |||
| all_fpr = np.unique(np.concatenate([precision_dict[i] for i in range(num_class)])) | |||
| mean_tpr = np.zeros_like(all_fpr) | |||
| for i in range(num_class): | |||
| mean_tpr += interp(all_fpr, precision_dict[i], recall_dict[i]) | |||
| mean_tpr /= num_class | |||
| precision_dict["macro"] = all_fpr | |||
| recall_dict["macro"] = mean_tpr | |||
| average_precision_dict["macro"] = auc(precision_dict["macro"], recall_dict["macro"]) | |||
| plt.figure() | |||
| plt.subplots(figsize=(16, 10)) | |||
| lw = 2 | |||
| plt.plot(precision_dict["micro"], recall_dict["micro"], | |||
| label='micro-average Precision-Recall curve (area = {0:0.2f})' | |||
| ''.format(average_precision_dict["micro"]), | |||
| color='deeppink', linestyle=':', linewidth=4) | |||
| plt.plot(precision_dict["macro"], recall_dict["macro"], | |||
| label='macro-average Precision-Recall curve (area = {0:0.2f})' | |||
| ''.format(average_precision_dict["macro"]), | |||
| color='navy', linestyle=':', linewidth=4) | |||
| colors = cycle(['aqua', 'darkorange', 'cornflowerblue']) | |||
| for i, color in zip(range(num_class), colors): | |||
| plt.plot(precision_dict[i], recall_dict[i], color=color, lw=lw, | |||
| label='Precision-Recall curve of class {0} (area = {1:0.2f})' | |||
| ''.format(i, average_precision_dict[i])) | |||
| plt.plot([0, 1], [0, 1], 'k--', lw=lw) | |||
| plt.xlabel('Recall') | |||
| plt.ylabel('Precision') | |||
| plt.ylim([0.0, 1.05]) | |||
| plt.xlim([0.0, 1.0]) | |||
| plt.legend(loc="lower left") | |||
| plt.savefig(fname=fname) | |||
| # plt.show() | |||
| def saving_in_tensorboard(config, x, y, fname='embedding'): | |||
| x = [i.cpu().numpy() for i in x] | |||
| y = [i.cpu().numpy() for i in y] | |||
| x = np.concatenate(x, axis=0) | |||
| y = np.concatenate(y, axis=0) | |||
| z = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||
| # config.writer.add_embedding(mat=x, label_img=y, metadata=z, tag=fname) | |||
| def plot_tsne(config, x, y, fname='tsne.png'): | |||
| x = [i.cpu().numpy() for i in x] | |||
| y = [i.cpu().numpy() for i in y] | |||
| x = np.concatenate(x, axis=0) | |||
| y = np.concatenate(y, axis=0) | |||
| y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||
| tsne = TSNE(n_components=2, verbose=1, init="pca", perplexity=10, learning_rate=1000) | |||
| tsne_proj = tsne.fit_transform(x) | |||
| fig, ax = plt.subplots(figsize=(16, 10)) | |||
| palette = sns.color_palette("bright", 2) | |||
| sns.scatterplot(tsne_proj[:, 0], tsne_proj[:, 1], hue=y, legend='full', palette=palette) | |||
| ax.legend(fontsize='large', markerscale=2) | |||
| plt.title('tsne of ' + str(fname.split('/')[-1].split('.')[0])) | |||
| plt.savefig(fname=fname) | |||
| # plt.show() | |||
| def plot_pca(config, x, y, fname='pca.png'): | |||
| x = [i.cpu().numpy() for i in x] | |||
| y = [i.cpu().numpy() for i in y] | |||
| x = np.concatenate(x, axis=0) | |||
| y = np.concatenate(y, axis=0) | |||
| y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values | |||
| pca = PCA(n_components=2) | |||
| pca_proj = pca.fit_transform(x) | |||
| fig, ax = plt.subplots(figsize=(16, 10)) | |||
| palette = sns.color_palette("bright", 2) | |||
| sns.scatterplot(pca_proj[:, 0], pca_proj[:, 1], hue=y, legend='full', palette=palette) | |||
| ax.legend(fontsize='large', markerscale=2) | |||
| plt.title('pca of ' + str(fname.split('/')[-1].split('.')[0])) | |||
| plt.savefig(fname=fname) | |||
| # plt.show() | |||
| @@ -0,0 +1,45 @@ | |||
| import timm | |||
| from torch import nn | |||
| from transformers import ViTModel, ViTConfig | |||
| class ImageTransformerEncoder(nn.Module): | |||
| def __init__(self, config): | |||
| super().__init__() | |||
| if config.pretrained: | |||
| self.model = ViTModel.from_pretrained(config.image_model_name, output_attentions=False, | |||
| output_hidden_states=True, return_dict=True) | |||
| else: | |||
| self.model = ViTModel(config=ViTConfig()) | |||
| for p in self.model.parameters(): | |||
| p.requires_grad = config.trainable | |||
| self.target_token_idx = 0 | |||
| self.image_encoder_embedding = dict() | |||
| def forward(self, ids, image): | |||
| output = self.model(image) | |||
| last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :] | |||
| # for i, id in enumerate(ids): | |||
| # id = int(id.detach().cpu().numpy()) | |||
| # self.image_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy() | |||
| return last_hidden_state | |||
| class ImageResnetEncoder(nn.Module): | |||
| def __init__( | |||
| self, config | |||
| ): | |||
| super().__init__() | |||
| self.model = timm.create_model( | |||
| config.image_model_name, config.pretrained, num_classes=0, global_pool="avg" | |||
| ) | |||
| for p in self.model.parameters(): | |||
| p.requires_grad = config.trainable | |||
| def forward(self, ids, image): | |||
| return self.model(image) | |||
| @@ -0,0 +1,206 @@ | |||
| import gc | |||
| import itertools | |||
| import random | |||
| import numpy as np | |||
| import optuna | |||
| import pandas as pd | |||
| import torch | |||
| from evaluation import multiclass_acc | |||
| from model import FakeNewsModel, calculate_loss | |||
| from utils import AvgMeter, print_lr, EarlyStopping, CheckpointSaving | |||
| def batch_constructor(config, batch): | |||
| b = {} | |||
| for k, v in batch.items(): | |||
| if k != 'text': | |||
| b[k] = v.to(config.device) | |||
| return b | |||
| def train_epoch(config, model, train_loader, optimizer, scalar): | |||
| loss_meter = AvgMeter('train') | |||
| # tqdm_object = tqdm(train_loader, total=len(train_loader)) | |||
| targets = [] | |||
| predictions = [] | |||
| for index, batch in enumerate(train_loader): | |||
| batch = batch_constructor(config, batch) | |||
| optimizer.zero_grad(set_to_none=True) | |||
| with torch.cuda.amp.autocast(): | |||
| output, score = model(batch) | |||
| loss = calculate_loss(model, score, batch['label']) | |||
| scalar.scale(loss).backward() | |||
| torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) | |||
| if (index + 1) % 2: | |||
| scalar.step(optimizer) | |||
| # loss.backward() | |||
| # optimizer.step() | |||
| scalar.update() | |||
| count = batch["id"].size(0) | |||
| loss_meter.update(loss.detach(), count) | |||
| # tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer)) | |||
| prediction = output.detach() | |||
| predictions.append(prediction) | |||
| target = batch['label'].detach() | |||
| targets.append(target) | |||
| return loss_meter, targets, predictions | |||
| def validation_epoch(config, model, validation_loader): | |||
| loss_meter = AvgMeter('validation') | |||
| targets = [] | |||
| predictions = [] | |||
| # tqdm_object = tqdm(validation_loader, total=len(validation_loader)) | |||
| for batch in validation_loader: | |||
| batch = batch_constructor(config, batch) | |||
| with torch.no_grad(): | |||
| output, score = model(batch) | |||
| loss = calculate_loss(model, score, batch['label']) | |||
| count = batch["id"].size(0) | |||
| loss_meter.update(loss.detach(), count) | |||
| # tqdm_object.set_postfix(validation_loss=loss_meter.avg) | |||
| prediction = output.detach() | |||
| predictions.append(prediction) | |||
| target = batch['label'].detach() | |||
| targets.append(target) | |||
| return loss_meter, targets, predictions | |||
| def supervised_train(config, train_loader, validation_loader, trial=None): | |||
| torch.cuda.empty_cache() | |||
| checkpoint_path2 = checkpoint_path = str(config.output_path) + '/checkpoint.pt' | |||
| if trial: | |||
| checkpoint_path2 = str(config.output_path) + '/checkpoint_' + str(trial.number) + '.pt' | |||
| torch.manual_seed(27) | |||
| random.seed(27) | |||
| np.random.seed(27) | |||
| torch.autograd.set_detect_anomaly(False) | |||
| torch.autograd.profiler.profile(False) | |||
| torch.autograd.profiler.emit_nvtx(False) | |||
| scalar = torch.cuda.amp.GradScaler() | |||
| model = FakeNewsModel(config).to(config.device) | |||
| params = [ | |||
| {"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr, "name": 'image_encoder'}, | |||
| {"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr, "name": 'text_encoder'}, | |||
| {"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()), | |||
| "lr": config.head_lr, "weight_decay": config.head_weight_decay, 'name': 'projection'}, | |||
| {"params": model.classifier.parameters(), "lr": config.classification_lr, | |||
| "weight_decay": config.classification_weight_decay, | |||
| 'name': 'classifier'} | |||
| ] | |||
| optimizer = torch.optim.AdamW(params, amsgrad=True) | |||
| lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.factor, | |||
| patience=config.patience // 5, verbose=True) | |||
| early_stopping = EarlyStopping(patience=config.patience, delta=config.delta, path=checkpoint_path, verbose=True) | |||
| checkpoint_saving = CheckpointSaving(path=checkpoint_path, verbose=True) | |||
| train_losses, train_accuracies = [], [] | |||
| validation_losses, validation_accuracies = [], [] | |||
| validation_accuracy, validation_loss = 0, 1 | |||
| for epoch in range(config.epochs): | |||
| print(f"Epoch: {epoch + 1}") | |||
| gc.collect() | |||
| model.train() | |||
| train_loss, train_truth, train_pred = train_epoch(config, model, train_loader, optimizer, scalar) | |||
| model.eval() | |||
| with torch.no_grad(): | |||
| validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader) | |||
| train_accuracy = multiclass_acc(train_truth, train_pred) | |||
| validation_accuracy = multiclass_acc(validation_truth, validation_pred) | |||
| print_lr(optimizer) | |||
| print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy) | |||
| print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy) | |||
| if lr_scheduler: | |||
| lr_scheduler.step(validation_loss.avg) | |||
| if early_stopping: | |||
| early_stopping(validation_loss.avg, model) | |||
| if early_stopping.early_stop: | |||
| print("Early stopping") | |||
| break | |||
| if checkpoint_saving: | |||
| checkpoint_saving(validation_accuracy, model) | |||
| train_accuracies.append(train_accuracy) | |||
| train_losses.append(train_loss) | |||
| validation_accuracies.append(validation_accuracy) | |||
| validation_losses.append(validation_loss) | |||
| if trial: | |||
| trial.report(validation_accuracy, epoch) | |||
| if trial.should_prune(): | |||
| print('trial pruned') | |||
| raise optuna.exceptions.TrialPruned() | |||
| print() | |||
| if checkpoint_saving: | |||
| model = FakeNewsModel(config).to(config.device) | |||
| model.load_state_dict(torch.load(checkpoint_path)) | |||
| model.eval() | |||
| with torch.no_grad(): | |||
| validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader) | |||
| validation_accuracy = multiclass_acc(validation_pred, validation_truth) | |||
| if trial and validation_accuracy >= config.wanted_accuracy: | |||
| loss_accuracy = pd.DataFrame( | |||
| {'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses, | |||
| 'validation_accuracy': validation_accuracies}) | |||
| torch.save({'model_state_dict': model.state_dict(), | |||
| 'parameters': str(config), | |||
| 'optimizer_state_dict': optimizer.state_dict(), | |||
| 'loss_accuracy': loss_accuracy}, checkpoint_path2) | |||
| if not checkpoint_saving: | |||
| loss_accuracy = pd.DataFrame( | |||
| {'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses, | |||
| 'validation_accuracy': validation_accuracies}) | |||
| torch.save(model.state_dict(), checkpoint_path) | |||
| if trial and validation_accuracy >= config.wanted_accuracy: | |||
| torch.save({'model_state_dict': model.state_dict(), | |||
| 'parameters': str(config), | |||
| 'optimizer_state_dict': optimizer.state_dict(), | |||
| 'loss_accuracy': loss_accuracy}, checkpoint_path2) | |||
| try: | |||
| del train_loss | |||
| del train_truth | |||
| del train_pred | |||
| del validation_loss | |||
| del validation_truth | |||
| del validation_pred | |||
| del train_losses | |||
| del train_accuracies | |||
| del validation_losses | |||
| del validation_accuracies | |||
| del loss_accuracy | |||
| del scalar | |||
| del model | |||
| del params | |||
| except: | |||
| print('Error in deleting caches') | |||
| pass | |||
| return validation_accuracy | |||
| @@ -0,0 +1,63 @@ | |||
| import argparse | |||
| from optuna_main import optuna_main | |||
| from test_main import test_main | |||
| from torch_main import torch_main | |||
| from data.twitter.config import TwitterConfig | |||
| from data.weibo.config import WeiboConfig | |||
| from torch.utils.tensorboard import SummaryWriter | |||
| if __name__ == '__main__': | |||
| import os | |||
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--data', type=str, required=True) | |||
| parser.add_argument('--use_optuna', type=int, required=False) | |||
| parser.add_argument('--just_test', type=int, required=False) | |||
| parser.add_argument('--batch', type=int, required=False) | |||
| parser.add_argument('--epoch', type=int, required=False) | |||
| parser.add_argument('--extra', type=str, required=False) | |||
| args = parser.parse_args() | |||
| if args.data == 'twitter': | |||
| config = TwitterConfig() | |||
| elif args.data == 'weibo': | |||
| config = WeiboConfig() | |||
| else: | |||
| raise Exception('Enter a valid dataset name', args.data) | |||
| if args.batch: | |||
| config.batch_size = args.batch | |||
| if args.epoch: | |||
| config.epochs = args.epoch | |||
| if args.use_optuna: | |||
| config.output_path += 'logs/' + args.data + '_optuna' + '_' + str(args.extra) | |||
| else: | |||
| config.output_path += 'logs/' + args.data + '_' + str(args.extra) | |||
| use_optuna = True | |||
| if not args.extra or 'temp' in args.extra: | |||
| config.output_path = str(args.extra) | |||
| use_optuna = False | |||
| try: | |||
| os.mkdir(config.output_path) | |||
| except OSError: | |||
| print("Creation of the directory failed") | |||
| else: | |||
| print("Successfully created the directory") | |||
| config.writer = SummaryWriter(config.output_path) | |||
| if args.use_optuna and use_optuna: | |||
| optuna_main(config, args.use_optuna) | |||
| elif args.just_test: | |||
| test_main(config, args.just_test) | |||
| else: | |||
| torch_main(config) | |||
| config.writer.close() | |||
| @@ -0,0 +1,143 @@ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| from image import ImageTransformerEncoder, ImageResnetEncoder | |||
| from text import TextEncoder | |||
| class ProjectionHead(nn.Module): | |||
| def __init__( | |||
| self, | |||
| config, | |||
| embedding_dim, | |||
| ): | |||
| super().__init__() | |||
| self.config = config | |||
| self.projection = nn.Linear(embedding_dim, config.projection_size) | |||
| self.gelu = nn.GELU() | |||
| self.fc = nn.Linear(config.projection_size, config.projection_size) | |||
| self.dropout = nn.Dropout(config.dropout) | |||
| self.layer_norm = nn.LayerNorm(config.projection_size) | |||
| def forward(self, x): | |||
| projected = self.projection(x) | |||
| x = self.gelu(projected) | |||
| x = self.fc(x) | |||
| x = self.dropout(x) | |||
| x = x + projected | |||
| x = self.layer_norm(x) | |||
| return x | |||
| class Classifier(nn.Module): | |||
| def __init__(self, config): | |||
| super().__init__() | |||
| self.layer_norm_1 = nn.LayerNorm(2 * config.projection_size) | |||
| self.linear_layer = nn.Linear(2 * config.projection_size, config.hidden_size) | |||
| self.gelu = nn.GELU() | |||
| self.drop_out = nn.Dropout(config.dropout) | |||
| self.layer_norm_2 = nn.LayerNorm(config.hidden_size) | |||
| self.classifier_layer = nn.Linear(config.hidden_size, config.class_num) | |||
| self.softmax = nn.Softmax(dim=1) | |||
| def forward(self, x): | |||
| x = self.layer_norm_1(x) | |||
| x = self.linear_layer(x) | |||
| x = self.gelu(x) | |||
| x = self.drop_out(x) | |||
| self.embeddings = x = self.layer_norm_2(x) | |||
| x = self.classifier_layer(x) | |||
| x = self.softmax(x) | |||
| return x | |||
| class FakeNewsModel(nn.Module): | |||
| def __init__( | |||
| self, config | |||
| ): | |||
| super().__init__() | |||
| self.config = config | |||
| if 'resnet' in self.config.image_model_name: | |||
| self.image_encoder = ImageResnetEncoder(config) | |||
| else: | |||
| self.image_encoder = ImageTransformerEncoder(config) | |||
| self.text_encoder = TextEncoder(config) | |||
| self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||
| self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||
| self.classifier = Classifier(config) | |||
| self.temperature = config.temperature | |||
| class_weights = torch.FloatTensor(config.class_weights) | |||
| self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||
| self.text_embeddings = None | |||
| self.image_embeddings = None | |||
| self.multimodal_embeddings = None | |||
| def forward(self, batch): | |||
| image_features = self.image_encoder(ids=batch['id'], image=batch["image"]) | |||
| text_features = self.text_encoder(ids=batch['id'], | |||
| input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) | |||
| self.image_embeddings = self.image_projection(image_features) | |||
| self.text_embeddings = self.text_projection(text_features) | |||
| self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||
| score = self.classifier(self.multimodal_embeddings) | |||
| probs, output = torch.max(score.data, dim=1) | |||
| return output, score | |||
| class FakeNewsEmbeddingModel(nn.Module): | |||
| def __init__( | |||
| self, config | |||
| ): | |||
| super().__init__() | |||
| self.config = config | |||
| self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||
| self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||
| self.classifier = Classifier(config) | |||
| self.temperature = config.temperature | |||
| class_weights = torch.FloatTensor(config.class_weights) | |||
| self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||
| def forward(self, batch): | |||
| self.image_embeddings = self.image_projection(batch['image_embedding']) | |||
| self.text_embeddings = self.text_projection(batch['text_embedding']) | |||
| self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||
| score = self.classifier(self.multimodal_embeddings) | |||
| probs, output = torch.max(score.data, dim=1) | |||
| return output, score | |||
| def calculate_loss(model, score, label): | |||
| similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings) | |||
| # fake = (label == 1).nonzero() | |||
| # real = (label == 0).nonzero() | |||
| # s_loss = 0 * similarity[fake].mean() + similarity[real].mean() | |||
| c_loss = model.classifier_loss_function(score, label) | |||
| loss = model.config.loss_weight * c_loss + similarity.mean() | |||
| return loss | |||
| def calculate_similarity_loss(config, image_embeddings, text_embeddings): | |||
| # Calculating the Loss | |||
| logits = (text_embeddings @ image_embeddings.T) / config.temperature | |||
| images_similarity = image_embeddings @ image_embeddings.T | |||
| texts_similarity = text_embeddings @ text_embeddings.T | |||
| targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1) | |||
| texts_loss = cross_entropy(logits, targets, reduction='none') | |||
| images_loss = cross_entropy(logits.T, targets.T, reduction='none') | |||
| loss = (images_loss + texts_loss) / 2.0 | |||
| return loss | |||
| def cross_entropy(preds, targets, reduction='none'): | |||
| log_softmax = nn.LogSoftmax(dim=-1) | |||
| loss = (-targets * log_softmax(preds)).sum(1) | |||
| if reduction == "none": | |||
| return loss | |||
| elif reduction == "mean": | |||
| return loss.mean() | |||
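# Illustrative sketch (not part of the model): exercising the similarity loss
# on random embeddings to show the expected shapes. _SketchConfig is a
# hypothetical stand-in for the real Config; only 'temperature' is needed here.
if __name__ == '__main__':
    class _SketchConfig:
        temperature = 1.0

    image_embeddings = torch.randn(8, 256)   # (batch_size, projection_size)
    text_embeddings = torch.randn(8, 256)
    per_sample = calculate_similarity_loss(_SketchConfig(), image_embeddings, text_embeddings)
    print(per_sample.shape)                  # torch.Size([8]); calculate_loss uses .mean()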
| @@ -0,0 +1,62 @@ | |||
| import joblib | |||
| import optuna | |||
| from optuna.pruners import MedianPruner | |||
| from optuna.trial import TrialState | |||
| from data_loaders import make_dfs, build_loaders | |||
| from learner import supervised_train | |||
| from test_main import test | |||
| def objective(trial, config, train_loader, validation_loader): | |||
| config.optuna(trial=trial) | |||
| print('Trial', trial.number, 'parameters', trial.params) | |||
| accuracy = supervised_train(config, train_loader, validation_loader, trial=trial) | |||
| return accuracy | |||
| def optuna_main(config, n_trials=100): | |||
| train_df, test_df, validation_df, _ = make_dfs(config) | |||
| train_loader = build_loaders(config, train_df, mode="train") | |||
| validation_loader = build_loaders(config, validation_df, mode="validation") | |||
| test_loader = build_loaders(config, test_df, mode="test") | |||
| study = optuna.create_study(study_name=config.output_path.split('/')[-1], | |||
| sampler=optuna.samplers.TPESampler(), | |||
| storage=f'sqlite:///{config.output_path + "/optuna.db"}', | |||
| load_if_exists=True, | |||
| direction="maximize", | |||
| pruner=MedianPruner(n_startup_trials=10, n_warmup_steps=10) | |||
| ) | |||
| study.optimize(lambda trial: objective(trial, config, train_loader, validation_loader), n_trials=n_trials) | |||
| pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED]) | |||
| complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE]) | |||
| joblib.dump(study, str(config.output_path) + '/study_optuna_model' + '.pkl') | |||
| print("Study statistics: ") | |||
| print(" Number of finished trials: ", len(study.trials)) | |||
| print(" Number of pruned trials: ", len(pruned_trials)) | |||
| print(" Number of complete trials: ", len(complete_trials)) | |||
| s = '' | |||
| print("Best trial:") | |||
| trial = study.best_trial | |||
| print(' Number: ', trial.number) | |||
| print(" Value: ", trial.value) | |||
| s += 'value: ' + str(trial.value) + '\n' | |||
| print(" Params: ") | |||
| s += 'params: \n' | |||
| for key, value in trial.params.items(): | |||
| print(" {}: {}".format(key, value)) | |||
| s += " {}: {}\n".format(key, value) | |||
| with open(config.output_path+'/optuna_results.txt', 'w') as f: | |||
| f.write(s) | |||
| test(config, test_loader, trial_number=trial.number) | |||
| @@ -0,0 +1,17 @@ | |||
albumentations
joblib
matplotlib
numpy
opencv-python==4.5.1.48
| optuna==2.8.0 | |||
| pandas==1.1.5 | |||
| Pillow==8.2.0 | |||
| pytorch-lightning==1.2.10 | |||
| pytorch-model-summary==0.1.2 | |||
| pytorchtools==0.0.2 | |||
| scikit-image==0.17.2 | |||
| scikit-learn==0.24.2 | |||
| scipy==1.5.4 | |||
| seaborn==0.11.1 | |||
timm
torch==1.8.1
| torchmetrics==0.2.0 | |||
| torchsummary==1.5.1 | |||
| torchvision==0.9.1 | |||
| tqdm==4.60.0 | |||
| transformers==4.5.1 | |||
| @@ -0,0 +1,111 @@ | |||
| import random | |||
| import numpy as np | |||
| import torch | |||
| from tqdm import tqdm | |||
| from data_loaders import make_dfs, build_loaders | |||
| from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca | |||
| from learner import batch_constructor | |||
| from model import FakeNewsModel | |||
| def test(config, test_loader, trial_number=None): | |||
| if trial_number: | |||
| try: | |||
| checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt') | |||
| except: | |||
| checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt') | |||
| else: | |||
| checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device)) | |||
| try: | |||
| parameters = checkpoint['parameters'] | |||
| config.assign_hyperparameters(parameters) | |||
| except: | |||
| pass | |||
| model = FakeNewsModel(config).to(config.device) | |||
| try: | |||
| model.load_state_dict(checkpoint['model_state_dict']) | |||
| except: | |||
| model.load_state_dict(checkpoint) | |||
| model.eval() | |||
| torch.manual_seed(27) | |||
| random.seed(27) | |||
| np.random.seed(27) | |||
| image_features = [] | |||
| text_features = [] | |||
| multimodal_features = [] | |||
| concat_features = [] | |||
| targets = [] | |||
| predictions = [] | |||
| scores = [] | |||
| tqdm_object = tqdm(test_loader, total=len(test_loader)) | |||
| for i, batch in enumerate(tqdm_object): | |||
| batch = batch_constructor(config, batch) | |||
| with torch.no_grad(): | |||
| output, score = model(batch) | |||
| prediction = output.detach() | |||
| predictions.append(prediction) | |||
| score = score.detach() | |||
| scores.append(score) | |||
| target = batch['label'].detach() | |||
| targets.append(target) | |||
| image_feature = model.image_embeddings.detach() | |||
| image_features.append(image_feature) | |||
| text_feature = model.text_embeddings.detach() | |||
| text_features.append(text_feature) | |||
| multimodal_feature = model.multimodal_embeddings.detach() | |||
| multimodal_features.append(multimodal_feature) | |||
| concat_feature = model.classifier.embeddings.detach() | |||
| concat_features.append(concat_feature) | |||
| # config.writer.add_graph(model, input_to_model=batch, verbose=True) | |||
| s = '' | |||
| s += report_per_class(targets, predictions) + '\n' | |||
| s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n' | |||
| with open(config.output_path + '/results.txt', 'w') as f: | |||
| f.write(s) | |||
| roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png") | |||
| precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png") | |||
| # saving_in_tensorboard(config, image_features, targets, 'image_features') | |||
| plot_tsne(config, image_features, targets, fname=str(config.output_path) + '/image_features_tsne.png') | |||
| plot_pca(config, image_features, targets, fname=str(config.output_path) + '/image_features_pca.png') | |||
| # saving_in_tensorboard(config, text_features, targets, 'text_features') | |||
| plot_tsne(config, text_features, targets, fname=str(config.output_path) + '/text_features_tsne.png') | |||
| plot_pca(config, text_features, targets, fname=str(config.output_path) + '/text_features_pca.png') | |||
| # | |||
| # saving_in_tensorboard(config, multimodal_features, targets, 'multimodal_features') | |||
| plot_tsne(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_tsne.png') | |||
| plot_pca(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_pca.png') | |||
| # saving_in_tensorboard(config, concat_features, targets, 'concat_features') | |||
| plot_tsne(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_tsne.png') | |||
| plot_pca(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_pca.png') | |||
| config_parameters = str(config) | |||
| with open(config.output_path + '/parameters.txt', 'w') as f: | |||
| f.write(config_parameters) | |||
| print(config) | |||
| def test_main(config, trial_number=None): | |||
| train_df, test_df, validation_df, _ = make_dfs(config, ) | |||
| test_loader = build_loaders(config, test_df, mode="test") | |||
| test(config, test_loader, trial_number) | |||
| @@ -0,0 +1,47 @@ | |||
| from torch import nn | |||
| from transformers import BertModel, BertConfig, BertTokenizer, \ | |||
| BigBirdModel, BigBirdConfig, BigBirdTokenizer, \ | |||
| XLNetModel, XLNetConfig, XLNetTokenizer | |||
| class TextEncoder(nn.Module): | |||
| def __init__(self, config): | |||
| super().__init__() | |||
| self.config = config | |||
| self.model = get_text_model(config) | |||
| for p in self.model.parameters(): | |||
| p.requires_grad = self.config.trainable | |||
| self.target_token_idx = 0 | |||
| self.text_encoder_embedding = dict() | |||
| def forward(self, ids, input_ids, attention_mask): | |||
| output = self.model(input_ids=input_ids, attention_mask=attention_mask) | |||
| last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :] | |||
| # for i, id in enumerate(ids): | |||
| # id = int(id.detach().cpu().numpy()) | |||
| # self.text_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy() | |||
| return last_hidden_state | |||
| def get_text_model(config): | |||
| if 'bigbird' in config.text_encoder_model: | |||
| if config.pretrained: | |||
| model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2) | |||
| else: | |||
| model = BigBirdModel(config=BigBirdConfig()) | |||
| elif 'xlnet' in config.text_encoder_model: | |||
| if config.pretrained: | |||
| model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||
| output_hidden_states=True, return_dict=True) | |||
| else: | |||
| model = XLNetModel(config=XLNetConfig()) | |||
| else: | |||
| if config.pretrained: | |||
| model = BertModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||
| output_hidden_states=True, return_dict=True) | |||
| else: | |||
| model = BertModel(config=BertConfig()) | |||
| return model | |||
| @@ -0,0 +1,19 @@ | |||
| from data_loaders import build_loaders, make_dfs | |||
| from learner import supervised_train | |||
| from test_main import test | |||
| def torch_main(config): | |||
| train_df, test_df, validation_df, _ = make_dfs(config, ) | |||
| train_loader = build_loaders(config, train_df, mode="train") | |||
| validation_loader = build_loaders(config, validation_df, mode="validation") | |||
| test_loader = build_loaders(config, test_df, mode="test") | |||
| supervised_train(config, train_loader, validation_loader) | |||
| test(config, test_loader) | |||
| @@ -0,0 +1,93 @@ | |||
| import numpy as np | |||
| import torch | |||
| class AvgMeter: | |||
| def __init__(self, name="Metric"): | |||
| self.name = name | |||
| self.reset() | |||
| def reset(self): | |||
| self.avg, self.sum, self.count = [0] * 3 | |||
| def update(self, val, count=1): | |||
| self.count += count | |||
| self.sum += val * count | |||
| self.avg = self.sum / self.count | |||
| def __repr__(self): | |||
| text = f"{self.name}: {self.avg:.4f}" | |||
| return text | |||
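# Usage sketch: AvgMeter keeps a count-weighted running average, e.g.
#   m = AvgMeter('train'); m.update(0.5, count=32); m.update(0.3, count=32)
#   m.avg  ->  0.4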
| def print_lr(optimizer): | |||
| for param_group in optimizer.param_groups: | |||
| print(param_group['name'], param_group['lr']) | |||
| class CheckpointSaving: | |||
| def __init__(self, path='checkpoint.pt', verbose=True, trace_func=print): | |||
| self.best_score = None | |||
| self.val_acc_max = 0 | |||
| self.path = path | |||
| self.verbose = verbose | |||
| self.trace_func = trace_func | |||
| def __call__(self, val_acc, model): | |||
| if self.best_score is None: | |||
| self.best_score = val_acc | |||
| self.save_checkpoint(val_acc, model) | |||
| elif val_acc > self.best_score: | |||
| self.best_score = val_acc | |||
| self.save_checkpoint(val_acc, model) | |||
| def save_checkpoint(self, val_acc, model): | |||
| if self.verbose: | |||
| self.trace_func( | |||
| f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}). Model saved ...') | |||
| torch.save(model.state_dict(), self.path) | |||
| self.val_acc_max = val_acc | |||
| class EarlyStopping: | |||
| def __init__(self, patience=10, verbose=False, delta=0.000001, path='checkpoint.pt', trace_func=print): | |||
| self.patience = patience | |||
| self.verbose = verbose | |||
| self.counter = 0 | |||
| self.best_score = None | |||
| self.early_stop = False | |||
| self.val_loss_min = np.Inf | |||
| self.delta = delta | |||
| self.path = path | |||
| self.trace_func = trace_func | |||
| def __call__(self, val_loss, model): | |||
| score = -val_loss | |||
| if self.best_score is None: | |||
| self.best_score = score | |||
| if self.verbose: | |||
| self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') | |||
| self.val_loss_min = val_loss | |||
| # self.save_checkpoint(val_loss, model) | |||
| elif score < self.best_score + self.delta: | |||
| self.counter += 1 | |||
| self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') | |||
| if self.counter >= self.patience: | |||
| self.early_stop = True | |||
| else: | |||
| self.best_score = score | |||
| if self.verbose: | |||
| self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') | |||
| # self.save_checkpoint(val_loss, model) | |||
| self.val_loss_min = val_loss | |||
| self.counter = 0 | |||
| # def save_checkpoint(self, val_loss, model): | |||
| # if self.verbose: | |||
| # self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...') | |||
| # torch.save(model.state_dict(), self.path) | |||
| # self.val_loss_min = val_loss | |||