size = 224
num_projection_layers = 1
projection_size = 64
dropout = 0.3
hidden_size = 128
num_region = 64  # 16 #TODO
region_size = projection_size // num_region
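# With the values above each region spans region_size = 64 // 64 = 1 projected
# feature; e.g. projection_size = 256 with num_region = 64 would give region_size = 4.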
import cv2
import torch
from numpy import asarray
from PIL import Image
from torch.utils.data import Dataset

import albumentations as A
from transformers import ViTFeatureExtractor
from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer
def get_transforms(config):
    return A.Compose(
        [
            A.Resize(config.size, config.size, always_apply=True),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ]
    )


def get_tokenizer(config):
    if 'bigbird' in config.text_encoder_model:
        tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
    elif 'xlnet' in config.text_encoder_model:
        tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
    else:
        tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
    return tokenizer
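# Usage sketch (illustrative checkpoint names, not from this diff): the tokenizer is
# chosen by substring-matching the encoder name, so any BERT-style checkpoint path
# falls through to BertTokenizer.
#
#   class DemoConfig:
#       text_encoder_model = 'bert-base-uncased'
#       text_tokenizer = 'bert-base-uncased'
#       max_length = 32
#
#   tokenizer = get_tokenizer(DemoConfig())
#   batch = tokenizer(['a sample headline'], padding=True, truncation=True,
#                     max_length=32, return_tensors='pt')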
class DatasetLoader(Dataset):
    def __init__(self, config, dataframe, mode):
        self.config = config
        self.mode = mode
        if mode == 'lime':
            self.image_filenames = [dataframe["image"], ]
            self.text = [dataframe["text"], ]
            self.labels = [dataframe["label"], ]
        else:
            self.image_filenames = dataframe["image"].values
            self.text = list(dataframe["text"].values)
            self.labels = dataframe["label"].values
        tokenizer = get_tokenizer(config)
        self.encoded_text = tokenizer(self.text, padding=True, truncation=True,
                                      max_length=config.max_length, return_tensors='pt')
        if 'resnet' in config.image_model_name:
            self.transforms = get_transforms(config)
        else:
            self.transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)
    def set_text(self, idx):
        # one example: the tokenizer outputs for this index plus label and id
        item = {
            key: values[idx].clone().detach()
            for key, values in self.encoded_text.items()
        }
        item['label'] = torch.tensor(self.labels[idx])
        item['id'] = idx
        return item
    def set_image(self, image):
        if 'resnet' in self.config.image_model_name:
            image = self.transforms(image=image)['image']
            # albumentations returns HWC; permute to channels-first instead of
            # reshape, which would scramble the pixel layout
            return {'image': torch.as_tensor(image).permute(2, 0, 1)}
        else:
            image = self.transforms(images=image, return_tensors='pt')
            image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
            return {'image': image.reshape((3, 224, 224))}
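# Shape note: both branches yield {'image': tensor of shape (3, 224, 224)},
# i.e. channels-first at config.size = 224.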
name = 'twitter'
DatasetLoader = TwitterDatasetLoader
data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/twitter/'
# data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/twitter/'
output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/'
# output_path = ''
train_image_path = data_path + 'images_train/'
validation_image_path = data_path + 'images_test/'
test_image_path = data_path + 'images_test/'
train_text_path = data_path + 'twitter_train_translated.csv'
validation_text_path = data_path + 'twitter_test_translated.csv'
test_text_path = data_path + 'twitter_test_translated.csv'
batch_size = 128
attention_weight_decay = 0.001
classification_weight_decay = 0.001
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
image_embedding = 768
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_embedding = 768
max_length = 32
class TwitterDatasetLoader(DatasetLoader):
    def __getitem__(self, idx):
        item = self.set_text(idx)
        item.update(self.set_img(idx))
        return item

    def set_img(self, idx):
        if self.mode == 'train':
            image_path = self.config.train_image_path
        elif self.mode == 'validation':
            image_path = self.config.validation_image_path
        else:
            image_path = self.config.test_image_path
        # cv2.imread returns None rather than raising when a file cannot be read,
        # so fall back to PIL in that case
        image = cv2.imread(f"{image_path}/{self.image_filenames[idx]}")
        if image is None:
            image = asarray(Image.open(f"{image_path}/{self.image_filenames[idx]}").convert('RGB'))
        return self.set_image(image)

    def __len__(self):
        return len(self.text)
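# Usage sketch: the loaders expect a dataframe with 'image', 'text' and 'label'
# columns (the columns read in __init__ above), e.g.
#
#   df = pd.read_csv(config.train_text_path)
#   dataset = TwitterDatasetLoader(config, dataframe=df, mode='train')
#   item = dataset[0]  # tokenizer fields plus 'label', 'id' and 'image'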
name = 'weibo'
DatasetLoader = WeiboDatasetLoader
data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/weibo/'
# data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/weibo/'
output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/'
# output_path = ''
rumor_image_path = data_path + 'rumor_images/'
nonrumor_image_path = data_path + 'nonrumor_images/'
train_text_path = data_path + 'weibo_train.csv'
validation_text_path = data_path + 'weibo_train.csv'
test_text_path = data_path + 'weibo_test.csv'
batch_size = 64
epochs = 100
num_workers = 4
head_lr = 1e-03
image_encoder_lr = 1e-02
text_encoder_lr = 1e-05
weight_decay = 0.001
classification_lr = 1e-02
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
image_embedding = 768
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_embedding = 768
max_length = 200
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
# self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)
self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64])
self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5])
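# Hedged sketch of how these suggestions are typically consumed by Optuna (the
# objective function is not part of this excerpt; 'objective' and the config hook
# name below are illustrative):
#
#   import optuna
#
#   def objective(trial):
#       config = WeiboConfig()
#       config.suggest_hyperparameters(trial)  # assumed wrapper around the calls above
#       # ... train, then return the validation metric to maximize ...
#
#   study = optuna.create_study(direction='maximize')
#   study.optimize(objective, n_trials=100)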
class WeiboDatasetLoader(DatasetLoader):
    def __getitem__(self, idx):
        item = self.set_text(idx)
        item.update(self.set_img(idx))
        return item

    def set_img(self, idx):
        # rumor and non-rumor images live in separate directories
        if self.labels[idx] == 1:
            image_path = self.config.rumor_image_path
        else:
            image_path = self.config.nonrumor_image_path
        # cv2.imread returns None rather than raising on unreadable files; fall back to PIL
        image = cv2.imread(f"{image_path}/{self.image_filenames[idx]}")
        if image is None:
            image = asarray(Image.open(f"{image_path}/{self.image_filenames[idx]}").convert('RGB'))
        return self.set_image(image)

    def __len__(self):
        return len(self.text)
import numpy as np
import pandas as pd
from sklearn.utils import compute_class_weight
from torch.utils.data import DataLoader
def make_dfs(config):
    validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True)
    validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x))
    return train_dataframe, test_dataframe, validation_dataframe


def build_loaders(config, dataframe, mode):
    dataset = config.DatasetLoader(config, dataframe=dataframe, mode=mode)
    if mode != 'train':
        dataloader = DataLoader(
            dataset,
            # LIME perturbations are scored in half-size batches; plain
            # validation/test evaluation runs one sample at a time
            batch_size=config.batch_size // 2 if mode == 'lime' else 1,
            num_workers=config.num_workers // 2,
            pin_memory=False,
            shuffle=False,
        )
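# Usage (as wired up in optuna_main/torch_main below):
#
#   train_df, test_df, validation_df = make_dfs(config)
#   train_loader = build_loaders(config, train_df, mode='train')
#   validation_loader = build_loaders(config, validation_df, mode='validation')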
import random

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from learner import batch_constructor
from lime.lime_image import LimeImageExplainer
from lime.lime_text import LimeTextExplainer
from model import FakeNewsModel
from skimage.segmentation import mark_boundaries
from utils import make_directory


def lime_(config, test_loader, model):
    # score every batch in the loader; return the class scores and similarity
    # logits of the last batch (lime loaders hold a single row)
    for i, batch in enumerate(test_loader):
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            output, score = model(batch)
            score = score.detach().cpu().numpy()
            logit = model.logits.detach().cpu().numpy()
    return score, logit
def lime_main(config, trial_number=None):
    _, test_df, _ = make_dfs(config)
    test_df = test_df[:1]

    if trial_number:
        try:
            checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
        except:
            checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
    else:
        checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))

    try:
        parameters = checkpoint['parameters']
        config.assign_hyperparameters(parameters)
    except:
        pass

    model = FakeNewsModel(config).to(config.device)
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except:
        model.load_state_dict(checkpoint)
    model.eval()

    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    text_explainer = LimeTextExplainer(class_names=config.classes)
    image_explainer = LimeImageExplainer()

    make_directory(config.output_path + '/lime/')
    make_directory(config.output_path + '/lime/score/')
    make_directory(config.output_path + '/lime/logit/')
    make_directory(config.output_path + '/lime/text/')
    make_directory(config.output_path + '/lime/image/')

    def text_predict_proba(texts):
        # called by LIME with a batch of perturbed texts; score each one through
        # the model via a fresh single-row loader
        scores = []
        for t in texts:
            row['text'] = t
            test_loader = build_loaders(config, row, mode="lime")
            score, _ = lime_(config, test_loader, model)
            scores.append(score)
        return np.array(scores)

    def image_predict_proba(images):
        scores = []
        for img in images:
            test_loader = build_loaders(config, row, mode="lime")
            # TODO: a DataLoader does not support item assignment; the perturbed
            # image needs to be injected into the underlying dataset instead
            test_loader['image'] = img.reshape((3, 224, 224))
            score, _ = lime_(config, test_loader, model)
            scores.append(score)
        return np.array(scores)

    for i, row in test_df.iterrows():
        test_loader = build_loaders(config, row, mode="lime")
        score, logit = lime_(config, test_loader, model)
        np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",")
        np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",")

        text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5)
        text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html')
        print('text', i, 'finished')

        data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0)
        # LIME's image explainer expects an HWC numpy array
        img_exp = image_explainer.explain_instance(data_items['image'].permute(1, 2, 0).numpy(),
                                                   image_predict_proba,
                                                   top_labels=2, hide_color=0, num_samples=1)
        temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False,
                                                num_features=5, hide_rest=False)
        img_boundry = mark_boundaries(temp / 255.0, mask)
        plt.imshow(img_boundry)
        plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png')
        print('image', i, 'finished')
import argparse

from lime_main import lime_main
from optuna_main import optuna_main
from test_main import test_main
from torch_main import torch_main
from data.weibo.config import WeiboConfig
from torch.utils.tensorboard import SummaryWriter
from utils import make_directory

if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--use_optuna', type=int, required=False)
    parser.add_argument('--use_lime', type=int, required=False)
    parser.add_argument('--just_test', type=int, required=False)
    parser.add_argument('--batch', type=int, required=False)
    parser.add_argument('--epoch', type=int, required=False)

    config.output_path = str(args.extra)
    use_optuna = False

    make_directory(config.output_path)
    config.writer = SummaryWriter(config.output_path)

        optuna_main(config, args.use_optuna)
    elif args.just_test:
        test_main(config, args.just_test)
    elif args.use_lime:
        lime_main(config, args.use_lime)
    else:
        torch_main(config)
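# Example invocations (hedged; the '--extra' output-path argument is referenced via
# args.extra above but its add_argument call is outside this excerpt):
#
#   python main.py --data twitter --use_optuna 100 --extra <output_dir>
#   python main.py --data weibo --use_lime 1 --extra <output_dir>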
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
self.classifier = Classifier(config)
class_weights = torch.FloatTensor(config.class_weights)
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

self.text_embeddings = self.text_projection(text_features)
self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
score = self.classifier(self.multimodal_embeddings)
probs, output = torch.max(score.data, dim=1)
return output, score
def calculate_loss(model, score, label):
    similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings)
    # fake = (label == 1).nonzero()


def calculate_similarity_loss(config, image_embeddings, text_embeddings):
    # CLIP-style symmetric loss with soft targets: the target distribution is the
    # softmax of the averaged intra-modal similarity matrices
    logits = (text_embeddings @ image_embeddings.T)
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss = (images_loss + texts_loss) / 2.0
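# The soft-target cross_entropy helper used above is not part of this excerpt.
# A minimal sketch of the usual implementation, assuming that is what is meant:
#
#   def cross_entropy(preds, targets, reduction='none'):
#       log_softmax = torch.nn.functional.log_softmax(preds, dim=-1)
#       loss = (-targets * log_softmax).sum(dim=1)
#       return loss if reduction == 'none' else loss.mean()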
def optuna_main(config, n_trials=100):
    train_df, test_df, validation_df = make_dfs(config)
    train_loader = build_loaders(config, train_df, mode="train")
    validation_loader = build_loaders(config, validation_df, mode="validation")
    test_loader = build_loaders(config, test_df, mode="test")

    trial = study.best_trial
    print(' Number: ', trial.number)
    print(" Value: ", trial.value)
    s += 'number: ' + str(trial.number) + '\n'
    s += 'value: ' + str(trial.value) + '\n'
    print(" Params: ")
def test_main(config, trial_number=None):
    train_df, test_df, validation_df = make_dfs(config)
    test_loader = build_loaders(config, test_df, mode="test")
    test(config, test_loader, trial_number)
def torch_main(config):
    train_df, test_df, validation_df = make_dfs(config)
    train_loader = build_loaders(config, train_df, mode="train")
    validation_loader = build_loaders(config, validation_df, mode="validation")
    test_loader = build_loaders(config, test_df, mode="test")
import os

import numpy as np
import torch


class AvgMeter:
    def __init__(self, name="Metric"):

# self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...')
# torch.save(model.state_dict(), self.path)
# self.val_loss_min = val_loss


def make_directory(path):
    try:
        os.mkdir(path)
    except OSError:
        print("Creation of the directory failed")
    else:
        print("Successfully created the directory")