# Pyre type checker
.pyre/
.idea/
env/
venv/
# FakeNewsRevealer
Official implementation of the "Fake News Revealer (FNR): A Similarity and Transformer-Based Approach to Detect
Multi-Modal Fake News in Social Media" paper [(ArXiv Link)](https://arxiv.org/pdf/2112.01131.pdf).

## Requirements
FNR is built with Python 3.6 and PyTorch 1.8. Please use the following command to install the requirements:
```
pip install -r requirements.txt
```

## How to Run
First, set the data paths and configuration in the config file under the data directory, then follow the train
and test commands below.
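For example, the Twitter dataset is configured in `data/twitter/config.py`; the paths below are illustrative placeholders and should point to your local copy of the data:
```
class TwitterConfig(Config):
    data_path = '/path/to/twitter/'
    train_image_path = data_path + 'images_train/'
    validation_image_path = data_path + 'images_validation/'
    test_image_path = data_path + 'images_test/'
    train_text_path = data_path + 'twitter_train_translated.csv'
    validation_text_path = data_path + 'twitter_validation_translated.csv'
    test_text_path = data_path + 'twitter_test_translated.csv'
```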
### train
To run with Optuna for hyperparameter tuning, use this command:
```
python main --data "DATA NAME" --use_optuna "NUMBER OF OPTUNA TRIALS" --batch "BATCH SIZE" --epoch "EPOCHS NUMBER"
```
To run without parameter tuning, set your parameters in the config file and then use the command below:
```
python main --data "DATA NAME" --batch "BATCH SIZE" --epoch "EPOCHS NUMBER"
```
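For example, a hypothetical Optuna run on the Twitter data with 20 trials, batch size 128, and 100 epochs:
```
python main --data twitter --use_optuna 20 --batch 128 --epoch 100
```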
### test
For testing, first make sure the requested checkpoint file exists, then run the following command:
```
python main --data "DATA NAME" --just_test "REQUESTED TRIAL NUMBER"
```
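For instance, to evaluate the checkpoint saved for (hypothetical) Optuna trial 5 on the Twitter data:
```
python main --data twitter --just_test 5
```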
## Bibtex
Cite our paper using the following BibTeX entry:
```
@misc{ghorbanpour2021fnr,
      title={FNR: A Similarity and Transformer-Based Approach to Detect Multi-Modal Fake News in Social Media},
      author={Faeze Ghorbanpour and Maryam Ramezani and Mohammad A. Fazli and Hamid R. Rabiee},
      year={2021},
      eprint={2112.01131},
      archivePrefix={arXiv},
}
```
import torch
import ast


class Config:
    name = ''
    DatasetLoader = None

    debug = False
    data_path = ''
    output_path = ''

    train_image_path = ''
    validation_image_path = ''
    test_image_path = ''

    train_text_path = ''
    validation_text_path = ''
    test_text_path = ''

    train_text_embedding_path = ''
    validation_text_embedding_path = ''
    train_image_embedding_path = ''
    validation_image_embedding_path = ''

    batch_size = 256
    epochs = 50
    num_workers = 2
    head_lr = 1e-03
    image_encoder_lr = 1e-04
    text_encoder_lr = 1e-04
    attention_lr = 1e-3
    classification_lr = 1e-03
    max_grad_norm = 5.0

    head_weight_decay = 0.001
    attention_weight_decay = 0.001
    classification_weight_decay = 0.001

    patience = 30
    delta = 0.0000001
    factor = 0.8

    image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
    image_embedding = 768
    num_img_region = 64  # 16 #TODO
    text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
    text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
    text_embedding = 768
    max_length = 32

    pretrained = True  # for both image encoder and text encoder
    trainable = False  # for both image encoder and text encoder
    temperature = 1.0

    # image size
    size = 224

    num_projection_layers = 1
    projection_size = 256
    dropout = 0.3
    hidden_size = 256
    num_region = 64  # 16 #TODO
    region_size = projection_size // num_region

    classes = ['real', 'fake']
    class_num = 2

    loss_weight = 1
    class_weights = [1, 1]

    writer = None
    has_unlabeled_data = False
    step = 0
    T1 = 10
    T2 = 150
    af = 3

    wanted_accuracy = 0.76

    def optuna(self, trial):
        self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1)
        self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3)
        self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3)
        self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)
        self.attention_lr = trial.suggest_loguniform('attention_lr', 1e-5, 1e-1)

        self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
        self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
        self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)

        self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
        self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64])
        self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5])

    def __str__(self):
        s = ''
        members = [attr for attr in dir(self) if not callable(getattr(self, attr)) and not attr.startswith("__")]
        for member in members:
            s += member + '\t' + str(getattr(self, member)) + '\n'
        return s

    def assign_hyperparameters(self, s):
        for line in s.split('\n'):
            pair = line.split('\t')
            try:
                attr = getattr(self, pair[0])
                if type(attr) not in [list, set, dict]:
                    setattr(self, pair[0], type(attr)(pair[1]))
                else:
                    setattr(self, pair[0], ast.literal_eval(pair[1]))
            except Exception:
                print(pair[0], 'does not exist')
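# --- Usage sketch (not part of the original repository) ----------------------
# Config.__str__ dumps every non-callable member as "name<TAB>value" lines, and
# assign_hyperparameters() parses such a dump back into the object; this is how
# the 'parameters' string stored in a checkpoint is restored at test time.
# Numeric and list-valued members round-trip; members that cannot be rebuilt
# simply print a warning and are skipped.
if __name__ == '__main__':
    cfg = Config()
    dump = str(cfg)                    # e.g. "batch_size\t256\n..."
    cfg.assign_hyperparameters(dump)   # restores the simple-typed members
    print(cfg.batch_size, cfg.dropout)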
import torch
from torch.utils.data import Dataset


class DatasetLoader(Dataset):
    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        self.encoded_text = tokenizer(
            list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt'
        )
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

    def set_text(self, idx):
        item = {
            key: values[idx].clone().detach()
            for key, values in self.encoded_text.items()
        }
        item['text'] = self.text[idx]
        item['label'] = self.labels[idx]
        item['id'] = idx
        return item

    def set_image(self, image, item):
        if 'resnet' in self.config.image_model_name:
            image = self.transforms(image=image)['image']
            item['image'] = torch.as_tensor(image).reshape((3, 224, 224))
        else:
            image = self.transforms(images=image, return_tensors='pt')
            image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
            item['image'] = image.reshape((3, 224, 224))
import torch

from data.config import Config
from data.twitter.data_loader import TwitterDatasetLoader


class TwitterConfig(Config):
    name = 'twitter'
    DatasetLoader = TwitterDatasetLoader

    data_path = '/twitter/'
    output_path = ''
    train_image_path = data_path + 'images_train/'
    validation_image_path = data_path + 'images_validation/'
    test_image_path = data_path + 'images_test/'

    train_text_path = data_path + 'twitter_train_translated.csv'
    validation_text_path = data_path + 'twitter_validation_translated.csv'
    test_text_path = data_path + 'twitter_test_translated.csv'

    batch_size = 128
    epochs = 100
    num_workers = 2
    head_lr = 1e-03
    image_encoder_lr = 1e-04
    text_encoder_lr = 1e-04
    attention_lr = 1e-3
    classification_lr = 1e-03

    head_weight_decay = 0.001
    attention_weight_decay = 0.001
    classification_weight_decay = 0.001

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    image_model_name = 'vit-base-patch16-224'
    image_embedding = 768
    text_encoder_model = "bert-base-uncased"
    text_tokenizer = "bert-base-uncased"
    text_embedding = 768
    max_length = 32

    pretrained = True
    trainable = False
    temperature = 1.0

    classes = ['real', 'fake']
    class_weights = [1, 1]

    wanted_accuracy = 0.76

    def optuna(self, trial):
        self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1)
        self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3)
        self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3)
        self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)

        self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
        # self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
        self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)

        # self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
        # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64])
        # self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5])
import cv2
import pickle

from PIL import Image
from numpy import asarray

from data.data_loader import DatasetLoader


class TwitterDatasetLoader(DatasetLoader):
    def __getitem__(self, idx):
        item = self.set_text(idx)
        try:
            if self.mode == 'train':
                image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}")
            elif self.mode == 'validation':
                image = cv2.imread(f"{self.config.validation_image_path}/{self.image_filenames[idx]}")
            else:
                image = cv2.imread(f"{self.config.test_image_path}/{self.image_filenames[idx]}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except Exception:
            # fall back to PIL when OpenCV cannot decode the file
            image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB')
            image = asarray(image)
        self.set_image(image, item)
        return item

    def __len__(self):
        return len(self.text)


class TwitterEmbeddingDatasetLoader(DatasetLoader):
    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        self.encoded_text = tokenizer
        self.labels = labels
        self.transforms = transforms
        self.mode = mode
        if mode == 'train':
            with open(config.train_image_embedding_path, 'rb') as handle:
                self.image_embedding = pickle.load(handle)
            with open(config.train_text_embedding_path, 'rb') as handle:
                self.text_embedding = pickle.load(handle)
        else:
            with open(config.validation_image_embedding_path, 'rb') as handle:
                self.image_embedding = pickle.load(handle)
            with open(config.validation_text_embedding_path, 'rb') as handle:
                self.text_embedding = pickle.load(handle)

    def __getitem__(self, idx):
        item = dict()
        item['id'] = idx
        item['image_embedding'] = self.image_embedding[idx]
        item['text_embedding'] = self.text_embedding[idx]
        item['label'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.text)
import torch
from transformers import BertTokenizer, BertModel, BertConfig

from data.config import Config
from data.weibo.data_loader import WeiboDatasetLoader


class WeiboConfig(Config):
    name = 'weibo'
    DatasetLoader = WeiboDatasetLoader

    data_path = 'weibo/'
    output_path = ''
    rumor_image_path = data_path + 'rumor_images/'
    nonrumor_image_path = data_path + 'nonrumor_images/'

    train_text_path = data_path + 'weibo_train.csv'
    validation_text_path = data_path + 'weibo_validation.csv'
    test_text_path = data_path + 'weibo_test.csv'

    batch_size = 128
    epochs = 100
    num_workers = 2
    head_lr = 1e-03
    image_encoder_lr = 1e-02
    text_encoder_lr = 1e-05
    weight_decay = 0.001
    classification_lr = 1e-02

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    image_model_name = 'vit-base-patch16-224'  # 'resnet101'
    image_embedding = 768  # 2048
    num_img_region = 64  # TODO
    text_encoder_model = "bert-base-chinese"
    text_tokenizer = "bert-base-chinese"
    text_embedding = 768
    max_length = 200

    pretrained = True
    trainable = False
    temperature = 1.0

    labels = ['real', 'fake']

    wanted_accuracy = 0.80

    def optuna(self, trial):
        self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1)
        self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3)
        self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3)
        self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)

        self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
        # self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
        self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)

        self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
        # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64])
        # self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5])
import pickle

import cv2
from PIL import Image
from numpy import asarray

from data.data_loader import DatasetLoader


class WeiboDatasetLoader(DatasetLoader):
    def __getitem__(self, idx):
        item = self.set_text(idx)

        if self.labels[idx] == 1:
            try:
                image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}")
            except Exception:
                image = Image.open(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}").convert('RGB')
                image = asarray(image)
        else:
            try:
                image = cv2.imread(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}")
            except Exception:
                image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB')
                image = asarray(image)

        self.set_image(image, item)
        return item

    def __len__(self):
        return len(self.text)


class WeiboEmbeddingDatasetLoader(DatasetLoader):
    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        self.encoded_text = tokenizer
        self.labels = labels
        self.transforms = transforms
        self.mode = mode
        if mode == 'train':
            with open(config.train_image_embedding_path, 'rb') as handle:
                self.image_embedding = pickle.load(handle)
            with open(config.train_text_embedding_path, 'rb') as handle:
                self.text_embedding = pickle.load(handle)
        else:
            with open(config.validation_image_embedding_path, 'rb') as handle:
                self.image_embedding = pickle.load(handle)
            with open(config.validation_text_embedding_path, 'rb') as handle:
                self.text_embedding = pickle.load(handle)

    def __getitem__(self, idx):
        item = dict()
        item['id'] = idx
        item['image_embedding'] = self.image_embedding[idx]
        item['text_embedding'] = self.text_embedding[idx]
        item['label'] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.text)
import albumentations as A
import numpy as np
import pandas as pd
from sklearn.utils import compute_class_weight
from torch.utils.data import DataLoader
from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer


def get_transforms(config, mode="train"):
    if mode == "train":
        return A.Compose(
            [
                A.Resize(config.size, config.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(config.size, config.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )


def make_dfs(config):
    train_dataframe = pd.read_csv(config.train_text_path)
    train_dataframe.dropna(subset=['text'], inplace=True)
    train_dataframe = train_dataframe.sample(frac=1).reset_index(drop=True)
    train_dataframe.label = train_dataframe.label.apply(lambda x: config.classes.index(x))
    config.class_weights = get_class_weights(train_dataframe.label.values)

    if config.test_text_path is None:
        offset = int(train_dataframe.shape[0] * 0.80)
        test_dataframe = train_dataframe[offset:]
        train_dataframe = train_dataframe[:offset]
    else:
        test_dataframe = pd.read_csv(config.test_text_path)
        test_dataframe.dropna(subset=['text'], inplace=True)
        test_dataframe = test_dataframe.sample(frac=1).reset_index(drop=True)
        test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x))

    if config.validation_text_path is None:
        offset = int(train_dataframe.shape[0] * 0.90)
        validation_dataframe = train_dataframe[offset:]
        train_dataframe = train_dataframe[:offset]
    else:
        validation_dataframe = pd.read_csv(config.validation_text_path)
        validation_dataframe.dropna(subset=['text'], inplace=True)
        validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True)
        validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x))

    unlabeled_dataframe = None
    if config.has_unlabeled_data:
        unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path)
        unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True)

    return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe


def get_tokenizer(config):
    if 'bigbird' in config.text_encoder_model:
        tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
    elif 'xlnet' in config.text_encoder_model:
        tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
    else:
        tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
    return tokenizer


def build_loaders(config, dataframe, mode):
    if 'resnet' in config.image_model_name:
        transforms = get_transforms(config, mode=mode)
    else:
        transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)
    dataset = config.DatasetLoader(
        config,
        dataframe["image"].values,
        dataframe["text"].values,
        dataframe["label"].values,
        tokenizer=get_tokenizer(config),
        transforms=transforms,
        mode=mode,
    )
    if mode != 'train':
        dataloader = DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=False,
            shuffle=False,
        )
    else:
        dataloader = DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            pin_memory=True,
            shuffle=True,
        )
    return dataloader


def get_class_weights(y):
    class_weights = compute_class_weight('balanced', classes=np.unique(y), y=y)
    return class_weights
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from numpy import interp
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, f1_score, precision_recall_curve, average_precision_score
from sklearn.metrics import classification_report, roc_curve, auc
from sklearn.preprocessing import OneHotEncoder


def metrics(truth, pred, prob, file_path):
    truth = [i.cpu().numpy() for i in truth]
    pred = [i.cpu().numpy() for i in pred]
    prob = [i.cpu().numpy() for i in prob]

    pred = np.concatenate(pred, axis=0)
    truth = np.concatenate(truth, axis=0)
    prob = np.concatenate(prob, axis=0)
    prob = prob[:, 1]

    f_score_micro = f1_score(truth, pred, average='micro', zero_division=0)
    f_score_macro = f1_score(truth, pred, average='macro', zero_division=0)
    f_score_weighted = f1_score(truth, pred, average='weighted', zero_division=0)
    accuracy = accuracy_score(truth, pred)

    s = ''
    print('accuracy', accuracy)
    s += 'accuracy ' + str(accuracy) + '\n'
    print('f_score_micro', f_score_micro)
    s += 'f_score_micro ' + str(f_score_micro) + '\n'
    print('f_score_macro', f_score_macro)
    s += 'f_score_macro ' + str(f_score_macro) + '\n'
    print('f_score_weighted', f_score_weighted)
    s += 'f_score_weighted ' + str(f_score_weighted) + '\n'

    fpr, tpr, thresholds = roc_curve(truth, prob)
    AUC = auc(fpr, tpr)
    print('AUC', AUC)
    s += 'AUC ' + str(AUC) + '\n'
    df = pd.DataFrame(dict(fpr=fpr, tpr=tpr))
    df.to_csv(file_path)
    return s


def report_per_class(truth, pred):
    truth = [i.cpu().numpy() for i in truth]
    pred = [i.cpu().numpy() for i in pred]
    pred = np.concatenate(pred, axis=0)
    truth = np.concatenate(truth, axis=0)
    report = classification_report(truth, pred, zero_division=0, output_dict=True)
    s = ''
    class_labels = [k for k in report.keys() if k not in ['micro avg', 'macro avg', 'weighted avg', 'samples avg']]
    for class_label in class_labels:
        print('class_label', class_label)
        s += 'class_label ' + str(class_label) + '\n'
        s += str(report[class_label])
        print(report[class_label])
    return s


def multiclass_acc(truth, pred):
    truth = [i.cpu().numpy() for i in truth]
    pred = [i.cpu().numpy() for i in pred]
    pred = np.concatenate(pred, axis=0)
    truth = np.concatenate(truth, axis=0)
    return accuracy_score(truth, pred)


def roc_auc_plot(truth, score, num_class=2, fname='roc.png'):
    truth = [i.cpu().numpy() for i in truth]
    score = [i.cpu().numpy() for i in score]
    truth = np.concatenate(truth, axis=0)
    score = np.concatenate(score, axis=0)

    enc = OneHotEncoder(handle_unknown='ignore')
    enc.fit(truth.reshape(-1, 1))
    label_onehot = enc.transform(truth.reshape(-1, 1)).toarray()

    fpr_dict = dict()
    tpr_dict = dict()
    roc_auc_dict = dict()
    for i in range(num_class):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score[:, i])
        roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])
    # micro-average ROC
    fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score.ravel())
    roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])
    # macro-average ROC
    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(num_class)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(num_class):
        mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i])
    mean_tpr /= num_class
    fpr_dict["macro"] = all_fpr
    tpr_dict["macro"] = mean_tpr
    roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])

    plt.figure()
    lw = 2
    plt.plot(fpr_dict["micro"], tpr_dict["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc_dict["micro"]),
             color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr_dict["macro"], tpr_dict["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc_dict["macro"]),
             color='navy', linestyle=':', linewidth=4)
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(num_class), colors):
        plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc_dict[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc="lower right")
    plt.savefig(fname)
    # plt.show()


def precision_recall_plot(truth, score, num_class=2, fname='pr.png'):
    truth = [i.cpu().numpy() for i in truth]
    score = [i.cpu().numpy() for i in score]
    truth = np.concatenate(truth, axis=0)
    score = np.concatenate(score, axis=0)

    enc = OneHotEncoder(handle_unknown='ignore')
    enc.fit(truth.reshape(-1, 1))
    label_onehot = enc.transform(truth.reshape(-1, 1)).toarray()

    # per-class precision and recall
    precision_dict = dict()
    recall_dict = dict()
    average_precision_dict = dict()
    for i in range(num_class):
        precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score[:, i])
        average_precision_dict[i] = average_precision_score(label_onehot[:, i], score[:, i])
        print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i])
    # micro-average
    precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(),
                                                                              score.ravel())
    average_precision_dict["micro"] = average_precision_score(label_onehot, score, average="micro")
    # macro-average
    all_fpr = np.unique(np.concatenate([precision_dict[i] for i in range(num_class)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(num_class):
        mean_tpr += interp(all_fpr, precision_dict[i], recall_dict[i])
    mean_tpr /= num_class
    precision_dict["macro"] = all_fpr
    recall_dict["macro"] = mean_tpr
    average_precision_dict["macro"] = auc(precision_dict["macro"], recall_dict["macro"])

    plt.figure()
    plt.subplots(figsize=(16, 10))
    lw = 2
    plt.plot(precision_dict["micro"], recall_dict["micro"],
             label='micro-average Precision-Recall curve (area = {0:0.2f})'.format(average_precision_dict["micro"]),
             color='deeppink', linestyle=':', linewidth=4)
    plt.plot(precision_dict["macro"], recall_dict["macro"],
             label='macro-average Precision-Recall curve (area = {0:0.2f})'.format(average_precision_dict["macro"]),
             color='navy', linestyle=':', linewidth=4)
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(num_class), colors):
        plt.plot(precision_dict[i], recall_dict[i], color=color, lw=lw,
                 label='Precision-Recall curve of class {0} (area = {1:0.2f})'.format(i, average_precision_dict[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.legend(loc="lower left")
    plt.savefig(fname=fname)
    # plt.show()


def saving_in_tensorboard(config, x, y, fname='embedding'):
    x = [i.cpu().numpy() for i in x]
    y = [i.cpu().numpy() for i in y]
    x = np.concatenate(x, axis=0)
    y = np.concatenate(y, axis=0)
    z = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values
    # config.writer.add_embedding(mat=x, label_img=y, metadata=z, tag=fname)


def plot_tsne(config, x, y, fname='tsne.png'):
    x = [i.cpu().numpy() for i in x]
    y = [i.cpu().numpy() for i in y]
    x = np.concatenate(x, axis=0)
    y = np.concatenate(y, axis=0)
    y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values

    tsne = TSNE(n_components=2, verbose=1, init="pca", perplexity=10, learning_rate=1000)
    tsne_proj = tsne.fit_transform(x)
    fig, ax = plt.subplots(figsize=(16, 10))
    palette = sns.color_palette("bright", 2)
    sns.scatterplot(tsne_proj[:, 0], tsne_proj[:, 1], hue=y, legend='full', palette=palette)
    ax.legend(fontsize='large', markerscale=2)
    plt.title('tsne of ' + str(fname.split('/')[-1].split('.')[0]))
    plt.savefig(fname=fname)
    # plt.show()


def plot_pca(config, x, y, fname='pca.png'):
    x = [i.cpu().numpy() for i in x]
    y = [i.cpu().numpy() for i in y]
    x = np.concatenate(x, axis=0)
    y = np.concatenate(y, axis=0)
    y = pd.DataFrame(y)[0].apply(lambda i: config.classes[i]).values

    pca = PCA(n_components=2)
    pca_proj = pca.fit_transform(x)
    fig, ax = plt.subplots(figsize=(16, 10))
    palette = sns.color_palette("bright", 2)
    sns.scatterplot(pca_proj[:, 0], pca_proj[:, 1], hue=y, legend='full', palette=palette)
    ax.legend(fontsize='large', markerscale=2)
    plt.title('pca of ' + str(fname.split('/')[-1].split('.')[0]))
    plt.savefig(fname=fname)
    # plt.show()
import timm
from torch import nn
from transformers import ViTModel, ViTConfig


class ImageTransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.pretrained:
            self.model = ViTModel.from_pretrained(config.image_model_name, output_attentions=False,
                                                  output_hidden_states=True, return_dict=True)
        else:
            self.model = ViTModel(config=ViTConfig())
        for p in self.model.parameters():
            p.requires_grad = config.trainable
        self.target_token_idx = 0
        self.image_encoder_embedding = dict()

    def forward(self, ids, image):
        output = self.model(image)
        # use the [CLS] token representation as the image feature
        last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :]
        # for i, id in enumerate(ids):
        #     id = int(id.detach().cpu().numpy())
        #     self.image_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy()
        return last_hidden_state


class ImageResnetEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = timm.create_model(
            config.image_model_name, config.pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = config.trainable

    def forward(self, ids, image):
        return self.model(image)
import gc
import itertools
import random

import numpy as np
import optuna
import pandas as pd
import torch

from evaluation import multiclass_acc
from model import FakeNewsModel, calculate_loss
from utils import AvgMeter, print_lr, EarlyStopping, CheckpointSaving


def batch_constructor(config, batch):
    b = {}
    for k, v in batch.items():
        if k != 'text':
            b[k] = v.to(config.device)
    return b


def train_epoch(config, model, train_loader, optimizer, scalar):
    loss_meter = AvgMeter('train')
    # tqdm_object = tqdm(train_loader, total=len(train_loader))

    targets = []
    predictions = []
    for index, batch in enumerate(train_loader):
        batch = batch_constructor(config, batch)
        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast():
            output, score = model(batch)
            loss = calculate_loss(model, score, batch['label'])

        scalar.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        if (index + 1) % 2:
            scalar.step(optimizer)
            # loss.backward()
            # optimizer.step()
            scalar.update()

        count = batch["id"].size(0)
        loss_meter.update(loss.detach(), count)
        # tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))

        prediction = output.detach()
        predictions.append(prediction)

        target = batch['label'].detach()
        targets.append(target)

    return loss_meter, targets, predictions


def validation_epoch(config, model, validation_loader):
    loss_meter = AvgMeter('validation')

    targets = []
    predictions = []
    # tqdm_object = tqdm(validation_loader, total=len(validation_loader))
    for batch in validation_loader:
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            output, score = model(batch)
            loss = calculate_loss(model, score, batch['label'])

        count = batch["id"].size(0)
        loss_meter.update(loss.detach(), count)
        # tqdm_object.set_postfix(validation_loss=loss_meter.avg)

        prediction = output.detach()
        predictions.append(prediction)

        target = batch['label'].detach()
        targets.append(target)

    return loss_meter, targets, predictions


def supervised_train(config, train_loader, validation_loader, trial=None):
    torch.cuda.empty_cache()
    checkpoint_path2 = checkpoint_path = str(config.output_path) + '/checkpoint.pt'
    if trial:
        checkpoint_path2 = str(config.output_path) + '/checkpoint_' + str(trial.number) + '.pt'
    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    torch.autograd.set_detect_anomaly(False)
    torch.autograd.profiler.profile(False)
    torch.autograd.profiler.emit_nvtx(False)

    scalar = torch.cuda.amp.GradScaler()
    model = FakeNewsModel(config).to(config.device)

    params = [
        {"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr, "name": 'image_encoder'},
        {"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr, "name": 'text_encoder'},
        {"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()),
         "lr": config.head_lr, "weight_decay": config.head_weight_decay, 'name': 'projection'},
        {"params": model.classifier.parameters(), "lr": config.classification_lr,
         "weight_decay": config.classification_weight_decay, 'name': 'classifier'}
    ]
    optimizer = torch.optim.AdamW(params, amsgrad=True)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.factor,
                                                              patience=config.patience // 5, verbose=True)
    early_stopping = EarlyStopping(patience=config.patience, delta=config.delta, path=checkpoint_path, verbose=True)
    checkpoint_saving = CheckpointSaving(path=checkpoint_path, verbose=True)

    train_losses, train_accuracies = [], []
    validation_losses, validation_accuracies = [], []
    validation_accuracy, validation_loss = 0, 1

    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}")
        gc.collect()

        model.train()
        train_loss, train_truth, train_pred = train_epoch(config, model, train_loader, optimizer, scalar)

        model.eval()
        with torch.no_grad():
            validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)

        train_accuracy = multiclass_acc(train_truth, train_pred)
        validation_accuracy = multiclass_acc(validation_truth, validation_pred)
        print_lr(optimizer)
        print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy)
        print('Validation Loss:', validation_loss, 'Validation Accuracy:', validation_accuracy)

        if lr_scheduler:
            lr_scheduler.step(validation_loss.avg)

        if early_stopping:
            early_stopping(validation_loss.avg, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        if checkpoint_saving:
            checkpoint_saving(validation_accuracy, model)

        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss)
        validation_accuracies.append(validation_accuracy)
        validation_losses.append(validation_loss)

        if trial:
            trial.report(validation_accuracy, epoch)
            if trial.should_prune():
                print('trial pruned')
                raise optuna.exceptions.TrialPruned()

        print()

    if checkpoint_saving:
        model = FakeNewsModel(config).to(config.device)
        model.load_state_dict(torch.load(checkpoint_path))
        model.eval()
        with torch.no_grad():
            validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)
        validation_accuracy = multiclass_acc(validation_pred, validation_truth)
        if trial and validation_accuracy >= config.wanted_accuracy:
            loss_accuracy = pd.DataFrame(
                {'train_loss': train_losses, 'train_accuracy': train_accuracies,
                 'validation_loss': validation_losses, 'validation_accuracy': validation_accuracies})
            torch.save({'model_state_dict': model.state_dict(),
                        'parameters': str(config),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss_accuracy': loss_accuracy}, checkpoint_path2)

    if not checkpoint_saving:
        loss_accuracy = pd.DataFrame(
            {'train_loss': train_losses, 'train_accuracy': train_accuracies,
             'validation_loss': validation_losses, 'validation_accuracy': validation_accuracies})
        torch.save(model.state_dict(), checkpoint_path)
        if trial and validation_accuracy >= config.wanted_accuracy:
            torch.save({'model_state_dict': model.state_dict(),
                        'parameters': str(config),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss_accuracy': loss_accuracy}, checkpoint_path2)

    try:
        del train_loss
        del train_truth
        del train_pred
        del validation_loss
        del validation_truth
        del validation_pred
        del train_losses
        del train_accuracies
        del validation_losses
        del validation_accuracies
        del loss_accuracy
        del scalar
        del model
        del params
    except Exception:
        print('Error in deleting caches')

    return validation_accuracy
import argparse
import os

from torch.utils.tensorboard import SummaryWriter

from data.twitter.config import TwitterConfig
from data.weibo.config import WeiboConfig
from optuna_main import optuna_main
from test_main import test_main
from torch_main import torch_main

if __name__ == '__main__':
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--use_optuna', type=int, required=False)
    parser.add_argument('--just_test', type=int, required=False)
    parser.add_argument('--batch', type=int, required=False)
    parser.add_argument('--epoch', type=int, required=False)
    parser.add_argument('--extra', type=str, required=False)
    args = parser.parse_args()

    if args.data == 'twitter':
        config = TwitterConfig()
    elif args.data == 'weibo':
        config = WeiboConfig()
    else:
        raise Exception('Enter a valid dataset name', args.data)

    if args.batch:
        config.batch_size = args.batch
    if args.epoch:
        config.epochs = args.epoch

    if args.use_optuna:
        config.output_path += 'logs/' + args.data + '_optuna' + '_' + str(args.extra)
    else:
        config.output_path += 'logs/' + args.data + '_' + str(args.extra)

    use_optuna = True
    if not args.extra or 'temp' in args.extra:
        config.output_path = str(args.extra)
        use_optuna = False

    try:
        os.mkdir(config.output_path)
    except OSError:
        print("Creation of the directory failed")
    else:
        print("Successfully created the directory")

    config.writer = SummaryWriter(config.output_path)

    if args.use_optuna and use_optuna:
        optuna_main(config, args.use_optuna)
    elif args.just_test:
        test_main(config, args.just_test)
    else:
        torch_main(config)

    config.writer.close()
import torch
import torch.nn.functional as F
from torch import nn

from image import ImageTransformerEncoder, ImageResnetEncoder
from text import TextEncoder


class ProjectionHead(nn.Module):
    def __init__(self, config, embedding_dim):
        super().__init__()
        self.config = config
        self.projection = nn.Linear(embedding_dim, config.projection_size)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(config.projection_size, config.projection_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.projection_size)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection around the projection
        x = self.layer_norm(x)
        return x


class Classifier(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(2 * config.projection_size)
        self.linear_layer = nn.Linear(2 * config.projection_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.drop_out = nn.Dropout(config.dropout)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.classifier_layer = nn.Linear(config.hidden_size, config.class_num)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.layer_norm_1(x)
        x = self.linear_layer(x)
        x = self.gelu(x)
        x = self.drop_out(x)
        self.embeddings = x = self.layer_norm_2(x)
        x = self.classifier_layer(x)
        x = self.softmax(x)
        return x


class FakeNewsModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        if 'resnet' in self.config.image_model_name:
            self.image_encoder = ImageResnetEncoder(config)
        else:
            self.image_encoder = ImageTransformerEncoder(config)
        self.text_encoder = TextEncoder(config)
        self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
        self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
        self.classifier = Classifier(config)
        self.temperature = config.temperature

        class_weights = torch.FloatTensor(config.class_weights)
        self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

        self.text_embeddings = None
        self.image_embeddings = None
        self.multimodal_embeddings = None

    def forward(self, batch):
        image_features = self.image_encoder(ids=batch['id'], image=batch["image"])
        text_features = self.text_encoder(ids=batch['id'],
                                          input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        self.image_embeddings = self.image_projection(image_features)
        self.text_embeddings = self.text_projection(text_features)
        self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
        score = self.classifier(self.multimodal_embeddings)
        probs, output = torch.max(score.data, dim=1)
        return output, score


class FakeNewsEmbeddingModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
        self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
        self.classifier = Classifier(config)
        self.temperature = config.temperature

        class_weights = torch.FloatTensor(config.class_weights)
        self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

    def forward(self, batch):
        self.image_embeddings = self.image_projection(batch['image_embedding'])
        self.text_embeddings = self.text_projection(batch['text_embedding'])
        self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
        score = self.classifier(self.multimodal_embeddings)
        probs, output = torch.max(score.data, dim=1)
        return output, score


def calculate_loss(model, score, label):
    similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings)
    # fake = (label == 1).nonzero()
    # real = (label == 0).nonzero()
    # s_loss = 0 * similarity[fake].mean() + similarity[real].mean()
    c_loss = model.classifier_loss_function(score, label)
    loss = model.config.loss_weight * c_loss + similarity.mean()
    return loss


def calculate_similarity_loss(config, image_embeddings, text_embeddings):
    # CLIP-style similarity loss between the projected image and text embeddings
    logits = (text_embeddings @ image_embeddings.T) / config.temperature
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1)
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss = (images_loss + texts_loss) / 2.0
    return loss


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
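# --- Usage sketch (not part of the original repository) ----------------------
# A quick sanity check of calculate_similarity_loss: with random projected
# embeddings of shape (batch, projection_size) it returns one loss value per
# sample. The tiny _DemoConfig below is a hypothetical stand-in for Config.
if __name__ == '__main__':
    class _DemoConfig:
        temperature = 1.0

    image_emb = torch.randn(4, 256)
    text_emb = torch.randn(4, 256)
    sim = calculate_similarity_loss(_DemoConfig(), image_emb, text_emb)
    print(sim.shape)  # torch.Size([4])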
import joblib
import optuna
from optuna.pruners import MedianPruner
from optuna.trial import TrialState

from data_loaders import make_dfs, build_loaders
from learner import supervised_train
from test_main import test


def objective(trial, config, train_loader, validation_loader):
    config.optuna(trial=trial)
    print('Trial', trial.number, 'parameters', trial.params)
    accuracy = supervised_train(config, train_loader, validation_loader, trial=trial)
    return accuracy


def optuna_main(config, n_trials=100):
    train_df, test_df, validation_df, _ = make_dfs(config)
    train_loader = build_loaders(config, train_df, mode="train")
    validation_loader = build_loaders(config, validation_df, mode="validation")
    test_loader = build_loaders(config, test_df, mode="test")

    study = optuna.create_study(study_name=config.output_path.split('/')[-1],
                                sampler=optuna.samplers.TPESampler(),
                                storage=f'sqlite:///{config.output_path + "/optuna.db"}',
                                load_if_exists=True,
                                direction="maximize",
                                pruner=MedianPruner(n_startup_trials=10, n_warmup_steps=10))
    study.optimize(lambda trial: objective(trial, config, train_loader, validation_loader), n_trials=n_trials)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    joblib.dump(study, str(config.output_path) + '/study_optuna_model.pkl')

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    s = ''
    print("Best trial:")
    trial = study.best_trial
    print('  Number: ', trial.number)
    print("  Value: ", trial.value)
    s += 'value: ' + str(trial.value) + '\n'
    print("  Params: ")
    s += 'params: \n'
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))
        s += "    {}: {}\n".format(key, value)

    with open(config.output_path + '/optuna_results.txt', 'w') as f:
        f.write(s)

    test(config, test_loader, trial_number=trial.number)
opencv-python==4.5.1.48
optuna==2.8.0
pandas==1.1.5
Pillow==8.2.0
pytorch-lightning==1.2.10
pytorch-model-summary==0.1.2
pytorchtools==0.0.2
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.5.4
seaborn==0.11.1
torch==1.8.1
torchmetrics==0.2.0
torchsummary==1.5.1
torchvision==0.9.1
tqdm==4.60.0
transformers==4.5.1
import random

import numpy as np
import torch
from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca
from learner import batch_constructor
from model import FakeNewsModel


def test(config, test_loader, trial_number=None):
    if trial_number:
        try:
            checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
        except Exception:
            checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
    else:
        checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))

    try:
        parameters = checkpoint['parameters']
        config.assign_hyperparameters(parameters)
    except Exception:
        pass

    model = FakeNewsModel(config).to(config.device)
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception:
        model.load_state_dict(checkpoint)
    model.eval()

    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    image_features = []
    text_features = []
    multimodal_features = []
    concat_features = []
    targets = []
    predictions = []
    scores = []
    tqdm_object = tqdm(test_loader, total=len(test_loader))
    for i, batch in enumerate(tqdm_object):
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            output, score = model(batch)

            prediction = output.detach()
            predictions.append(prediction)

            score = score.detach()
            scores.append(score)

            target = batch['label'].detach()
            targets.append(target)

            image_feature = model.image_embeddings.detach()
            image_features.append(image_feature)

            text_feature = model.text_embeddings.detach()
            text_features.append(text_feature)

            multimodal_feature = model.multimodal_embeddings.detach()
            multimodal_features.append(multimodal_feature)

            concat_feature = model.classifier.embeddings.detach()
            concat_features.append(concat_feature)

    # config.writer.add_graph(model, input_to_model=batch, verbose=True)

    s = ''
    s += report_per_class(targets, predictions) + '\n'
    s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n'
    with open(config.output_path + '/results.txt', 'w') as f:
        f.write(s)

    roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
    precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")

    # saving_in_tensorboard(config, image_features, targets, 'image_features')
    plot_tsne(config, image_features, targets, fname=str(config.output_path) + '/image_features_tsne.png')
    plot_pca(config, image_features, targets, fname=str(config.output_path) + '/image_features_pca.png')

    # saving_in_tensorboard(config, text_features, targets, 'text_features')
    plot_tsne(config, text_features, targets, fname=str(config.output_path) + '/text_features_tsne.png')
    plot_pca(config, text_features, targets, fname=str(config.output_path) + '/text_features_pca.png')

    # saving_in_tensorboard(config, multimodal_features, targets, 'multimodal_features')
    plot_tsne(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_tsne.png')
    plot_pca(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_pca.png')

    # saving_in_tensorboard(config, concat_features, targets, 'concat_features')
    plot_tsne(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_tsne.png')
    plot_pca(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_pca.png')

    config_parameters = str(config)
    with open(config.output_path + '/parameters.txt', 'w') as f:
        f.write(config_parameters)
    print(config)


def test_main(config, trial_number=None):
    train_df, test_df, validation_df, _ = make_dfs(config)
    test_loader = build_loaders(config, test_df, mode="test")
    test(config, test_loader, trial_number)
from torch import nn
from transformers import BertModel, BertConfig, BertTokenizer, \
    BigBirdModel, BigBirdConfig, BigBirdTokenizer, \
    XLNetModel, XLNetConfig, XLNetTokenizer


class TextEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = get_text_model(config)
        for p in self.model.parameters():
            p.requires_grad = self.config.trainable
        self.target_token_idx = 0
        self.text_encoder_embedding = dict()

    def forward(self, ids, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :]
        # for i, id in enumerate(ids):
        #     id = int(id.detach().cpu().numpy())
        #     self.text_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy()
        return last_hidden_state


def get_text_model(config):
    if 'bigbird' in config.text_encoder_model:
        if config.pretrained:
            model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2)
        else:
            model = BigBirdModel(config=BigBirdConfig())
    elif 'xlnet' in config.text_encoder_model:
        if config.pretrained:
            model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False,
                                               output_hidden_states=True, return_dict=True)
        else:
            model = XLNetModel(config=XLNetConfig())
    else:
        if config.pretrained:
            model = BertModel.from_pretrained(config.text_encoder_model, output_attentions=False,
                                              output_hidden_states=True, return_dict=True)
        else:
            model = BertModel(config=BertConfig())
    return model
from data_loaders import build_loaders, make_dfs
from learner import supervised_train
from test_main import test


def torch_main(config):
    train_df, test_df, validation_df, _ = make_dfs(config)
    train_loader = build_loaders(config, train_df, mode="train")
    validation_loader = build_loaders(config, validation_df, mode="validation")
    test_loader = build_loaders(config, test_df, mode="test")

    supervised_train(config, train_loader, validation_loader)
    test(config, test_loader)
import numpy as np
import torch


class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text


def print_lr(optimizer):
    for param_group in optimizer.param_groups:
        print(param_group['name'], param_group['lr'])


class CheckpointSaving:
    def __init__(self, path='checkpoint.pt', verbose=True, trace_func=print):
        self.best_score = None
        self.val_acc_max = 0
        self.path = path
        self.verbose = verbose
        self.trace_func = trace_func

    def __call__(self, val_acc, model):
        if self.best_score is None:
            self.best_score = val_acc
            self.save_checkpoint(val_acc, model)
        elif val_acc > self.best_score:
            self.best_score = val_acc
            self.save_checkpoint(val_acc, model)

    def save_checkpoint(self, val_acc, model):
        if self.verbose:
            self.trace_func(
                f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}). Model saved ...')
        torch.save(model.state_dict(), self.path)
        self.val_acc_max = val_acc


class EarlyStopping:
    def __init__(self, patience=10, verbose=False, delta=0.000001, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            if self.verbose:
                self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
            self.val_loss_min = val_loss
            # self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            if self.verbose:
                self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
            # self.save_checkpoint(val_loss, model)
            self.val_loss_min = val_loss
            self.counter = 0

    # def save_checkpoint(self, val_loss, model):
    #     if self.verbose:
    #         self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...')
    #     torch.save(model.state_dict(), self.path)
    #     self.val_loss_min = val_loss
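# --- Usage sketch (not part of the original repository) ----------------------
# EarlyStopping tracks the validation loss and sets early_stop once `patience`
# consecutive epochs fail to improve on the best loss seen so far.
if __name__ == '__main__':
    stopper = EarlyStopping(patience=3, verbose=False)
    for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.96, 0.97]):
        stopper(val_loss, model=None)  # model is unused unless checkpointing is re-enabled
        if stopper.early_stop:
            print('stopping at epoch', epoch)
            break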