 FaezeGhorbanpour
					
					3 years ago
						FaezeGhorbanpour
					
					3 years ago
				| @@ -59,9 +59,9 @@ class Config: | |||
| size = 224 | |||
| num_projection_layers = 1 | |||
| projection_size = 256 | |||
| projection_size = 64 | |||
| dropout = 0.3 | |||
| hidden_size = 256 | |||
| hidden_size = 128 | |||
| num_region = 64 # 16 #TODO | |||
| region_size = projection_size // num_region | |||
| @@ -1,17 +1,48 @@ | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| import albumentations as A | |||
| from transformers import ViTFeatureExtractor | |||
| from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||
def get_transforms(config):
    """Build the albumentations preprocessing pipeline for input images.

    The pipeline resizes every image to a ``config.size`` x ``config.size``
    square and then normalizes it, treating 255.0 as the maximum pixel value.
    Both steps are flagged ``always_apply`` so they run on every sample.
    """
    side = config.size
    steps = [
        A.Resize(side, side, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)
def get_tokenizer(config):
    """Load the tokenizer that matches the configured text encoder.

    The tokenizer class is chosen by substring match on
    ``config.text_encoder_model`` ('bigbird' first, then 'xlnet', otherwise
    BERT); the weights/vocabulary are always loaded from
    ``config.text_tokenizer``.
    """
    encoder_name = config.text_encoder_model
    # Checked in order; the first marker found in the encoder name wins.
    candidates = (
        ('bigbird', BigBirdTokenizer),
        ('xlnet', XLNetTokenizer),
    )
    for marker, tokenizer_class in candidates:
        if marker in encoder_name:
            break
    else:
        tokenizer_class = BertTokenizer
    return tokenizer_class.from_pretrained(config.text_tokenizer)
| class DatasetLoader(Dataset): | |||
| def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
| def __init__(self, config, dataframe, mode): | |||
| self.config = config | |||
| self.image_filenames = image_filenames | |||
| self.text = list(texts) | |||
| self.encoded_text = tokenizer( | |||
| list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt' | |||
| ) | |||
| self.labels = labels | |||
| self.transforms = transforms | |||
| self.mode = mode | |||
| if mode == 'lime': | |||
| self.image_filenames = [dataframe["image"],] | |||
| self.text = [dataframe["text"],] | |||
| self.labels = [dataframe["label"],] | |||
| else: | |||
| self.image_filenames = dataframe["image"].values | |||
| self.text = list(dataframe["text"].values) | |||
| self.labels = dataframe["label"].values | |||
| tokenizer = get_tokenizer(config) | |||
| self.encoded_text = tokenizer(self.text, padding=True, truncation=True, max_length=config.max_length, return_tensors='pt') | |||
| if 'resnet' in config.image_model_name: | |||
| self.transforms = get_transforms(config) | |||
| else: | |||
| self.transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name) | |||
| def set_text(self, idx): | |||
| item = { | |||
| @@ -23,11 +54,11 @@ class DatasetLoader(Dataset): | |||
| item['id'] = idx | |||
| return item | |||
| def set_image(self, image, item): | |||
| def set_image(self, image): | |||
| if 'resnet' in self.config.image_model_name: | |||
| image = self.transforms(image=image)['image'] | |||
| item['image'] = torch.as_tensor(image).reshape((3, 224, 224)) | |||
| return {'image': torch.as_tensor(image).reshape((3, 224, 224))} | |||
| else: | |||
| image = self.transforms(images=image, return_tensors='pt') | |||
| image = image.convert_to_tensors(tensor_type='pt')['pixel_values'] | |||
| item['image'] = image.reshape((3, 224, 224)) | |||
| return {'image': image.reshape((3, 224, 224))} | |||
| @@ -8,15 +8,17 @@ class TwitterConfig(Config): | |||
| name = 'twitter' | |||
| DatasetLoader = TwitterDatasetLoader | |||
| data_path = '/twitter/' | |||
| output_path = '' | |||
| data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/twitter/' | |||
| # data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/twitter/' | |||
| output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/' | |||
| # output_path = '' | |||
| train_image_path = data_path + 'images_train/' | |||
| validation_image_path = data_path + 'images_validation/' | |||
| validation_image_path = data_path + 'images_test/' | |||
| test_image_path = data_path + 'images_test/' | |||
| train_text_path = data_path + 'twitter_train_translated.csv' | |||
| validation_text_path = data_path + 'twitter_validation_translated.csv' | |||
| validation_text_path = data_path + 'twitter_test_translated.csv' | |||
| test_text_path = data_path + 'twitter_test_translated.csv' | |||
| batch_size = 128 | |||
| @@ -32,12 +34,14 @@ class TwitterConfig(Config): | |||
| attention_weight_decay = 0.001 | |||
| classification_weight_decay = 0.001 | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") | |||
| image_model_name = 'vit-base-patch16-224' | |||
| image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||
| image_embedding = 768 | |||
| text_encoder_model = "bert-base-uncased" | |||
| text_tokenizer = "bert-base-uncased" | |||
| text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
| # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
| text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
| # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
| text_embedding = 768 | |||
| max_length = 32 | |||
| @@ -10,7 +10,10 @@ from data.data_loader import DatasetLoader | |||
| class TwitterDatasetLoader(DatasetLoader): | |||
| def __getitem__(self, idx): | |||
| item = self.set_text(idx) | |||
| item.update(self.set_img(idx)) | |||
| return item | |||
| def set_img(self, idx): | |||
| try: | |||
| if self.mode == 'train': | |||
| image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}") | |||
| @@ -22,44 +25,8 @@ class TwitterDatasetLoader(DatasetLoader): | |||
| except: | |||
| image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
| image = asarray(image) | |||
| self.set_image(image, item) | |||
| return item | |||
| return self.set_image(image) | |||
| def __len__(self): | |||
| return len(self.text) | |||
| class TwitterEmbeddingDatasetLoader(DatasetLoader): | |||
| def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
| super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||
| self.config = config | |||
| self.image_filenames = image_filenames | |||
| self.text = list(texts) | |||
| self.encoded_text = tokenizer | |||
| self.labels = labels | |||
| self.transforms = transforms | |||
| self.mode = mode | |||
| if mode == 'train': | |||
| with open(config.train_image_embedding_path, 'rb') as handle: | |||
| self.image_embedding = pickle.load(handle) | |||
| with open(config.train_text_embedding_path, 'rb') as handle: | |||
| self.text_embedding = pickle.load(handle) | |||
| else: | |||
| with open(config.validation_image_embedding_path, 'rb') as handle: | |||
| self.image_embedding = pickle.load(handle) | |||
| with open(config.validation_text_embedding_path, 'rb') as handle: | |||
| self.text_embedding = pickle.load(handle) | |||
| def __getitem__(self, idx): | |||
| item = dict() | |||
| item['id'] = idx | |||
| item['image_embedding'] = self.image_embedding[idx] | |||
| item['text_embedding'] = self.text_embedding[idx] | |||
| item['label'] = self.labels[idx] | |||
| return item | |||
| def __len__(self): | |||
| return len(self.text) | |||
| @@ -9,32 +9,35 @@ class WeiboConfig(Config): | |||
| name = 'weibo' | |||
| DatasetLoader = WeiboDatasetLoader | |||
| data_path = 'weibo/' | |||
| output_path = '' | |||
| data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/weibo/' | |||
| # data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/weibo/' | |||
| output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/' | |||
| # output_path = '' | |||
| rumor_image_path = data_path + 'rumor_images/' | |||
| nonrumor_image_path = data_path + 'nonrumor_images/' | |||
| train_text_path = data_path + 'weibo_train.csv' | |||
| validation_text_path = data_path + 'weibo_validation.csv' | |||
| validation_text_path = data_path + 'weibo_train.csv' | |||
| test_text_path = data_path + 'weibo_test.csv' | |||
| batch_size = 128 | |||
| batch_size = 64 | |||
| epochs = 100 | |||
| num_workers = 2 | |||
| num_workers = 4 | |||
| head_lr = 1e-03 | |||
| image_encoder_lr = 1e-02 | |||
| text_encoder_lr = 1e-05 | |||
| weight_decay = 0.001 | |||
| classification_lr = 1e-02 | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") | |||
| image_model_name = 'vit-base-patch16-224' # 'resnet101' | |||
| image_embedding = 768 # 2048 | |||
| num_img_region = 64 # TODO | |||
| text_encoder_model = "bert-base-chinese" | |||
| text_tokenizer = "bert-base-chinese" | |||
| image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||
| image_embedding = 768 | |||
| text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
| # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
| text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
| # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
| text_embedding = 768 | |||
| max_length = 200 | |||
| @@ -53,10 +56,9 @@ class WeiboConfig(Config): | |||
| self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||
| self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||
| # self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||
| self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||
| self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||
| # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
| # self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
| self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
| self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
| @@ -11,7 +11,10 @@ from data.data_loader import DatasetLoader | |||
| class WeiboDatasetLoader(DatasetLoader): | |||
| def __getitem__(self, idx): | |||
| item = self.set_text(idx) | |||
| item.update(self.set_img(idx)) | |||
| return item | |||
| def set_img(self, idx): | |||
| if self.labels[idx] == 1: | |||
| try: | |||
| image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}") | |||
| @@ -24,43 +27,8 @@ class WeiboDatasetLoader(DatasetLoader): | |||
| except: | |||
| image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
| image = asarray(image) | |||
| self.set_image(image, item) | |||
| return item | |||
| return self.set_image(image) | |||
| def __len__(self): | |||
| return len(self.text) | |||
| class WeiboEmbeddingDatasetLoader(DatasetLoader): | |||
| def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||
| super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||
| self.config = config | |||
| self.image_filenames = image_filenames | |||
| self.text = list(texts) | |||
| self.encoded_text = tokenizer | |||
| self.labels = labels | |||
| self.transforms = transforms | |||
| self.mode = mode | |||
| if mode == 'train': | |||
| with open(config.train_image_embedding_path, 'rb') as handle: | |||
| self.image_embedding = pickle.load(handle) | |||
| with open(config.train_text_embedding_path, 'rb') as handle: | |||
| self.text_embedding = pickle.load(handle) | |||
| else: | |||
| with open(config.validation_image_embedding_path, 'rb') as handle: | |||
| self.image_embedding = pickle.load(handle) | |||
| with open(config.validation_text_embedding_path, 'rb') as handle: | |||
| self.text_embedding = pickle.load(handle) | |||
| def __getitem__(self, idx): | |||
| item = dict() | |||
| item['id'] = idx | |||
| item['image_embedding'] = self.image_embedding[idx] | |||
| item['text_embedding'] = self.text_embedding[idx] | |||
| item['label'] = self.labels[idx] | |||
| return item | |||
| def __len__(self): | |||
| return len(self.text) | |||
| @@ -1,27 +1,8 @@ | |||
| import albumentations as A | |||
| import numpy as np | |||
| import pandas as pd | |||
| from sklearn.utils import compute_class_weight | |||
| from torch.utils.data import DataLoader | |||
| from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||
| def get_transforms(config, mode="train"): | |||
| if mode == "train": | |||
| return A.Compose( | |||
| [ | |||
| A.Resize(config.size, config.size, always_apply=True), | |||
| A.Normalize(max_pixel_value=255.0, always_apply=True), | |||
| ] | |||
| ) | |||
| else: | |||
| return A.Compose( | |||
| [ | |||
| A.Resize(config.size, config.size, always_apply=True), | |||
| A.Normalize(max_pixel_value=255.0, always_apply=True), | |||
| ] | |||
| ) | |||
| def make_dfs(config): | |||
| @@ -51,44 +32,16 @@ def make_dfs(config): | |||
| validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True) | |||
| validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
| unlabeled_dataframe = None | |||
| if config.has_unlabeled_data: | |||
| unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path) | |||
| unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True) | |||
| return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe | |||
| def get_tokenizer(config): | |||
| if 'bigbird' in config.text_encoder_model: | |||
| tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||
| elif 'xlnet' in config.text_encoder_model: | |||
| tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||
| else: | |||
| tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer) | |||
| return tokenizer | |||
| return train_dataframe, test_dataframe, validation_dataframe | |||
| def build_loaders(config, dataframe, mode): | |||
| if 'resnet' in config.image_model_name: | |||
| transforms = get_transforms(config, mode=mode) | |||
| else: | |||
| transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name) | |||
| dataset = config.DatasetLoader( | |||
| config, | |||
| dataframe["image"].values, | |||
| dataframe["text"].values, | |||
| dataframe["label"].values, | |||
| tokenizer=get_tokenizer(config), | |||
| transforms=transforms, | |||
| mode=mode, | |||
| ) | |||
| dataset = config.DatasetLoader(config, dataframe=dataframe, mode=mode) | |||
| if mode != 'train': | |||
| dataloader = DataLoader( | |||
| dataset, | |||
| batch_size=config.batch_size, | |||
| num_workers=config.num_workers, | |||
| batch_size=config.batch_size // 2 if mode == 'lime' else 1, | |||
| num_workers=config.num_workers // 2, | |||
| pin_memory=False, | |||
| shuffle=False, | |||
| ) | |||
| @@ -0,0 +1,110 @@ | |||
| import random | |||
| import numpy as np | |||
| import pandas as pd | |||
| import torch | |||
| from matplotlib import pyplot as plt | |||
| from tqdm import tqdm | |||
| from data_loaders import make_dfs, build_loaders | |||
| from learner import batch_constructor | |||
| from model import FakeNewsModel | |||
| from lime.lime_image import LimeImageExplainer | |||
| from skimage.segmentation import mark_boundaries | |||
| from lime.lime_text import LimeTextExplainer | |||
| from utils import make_directory | |||
| def lime_(config, test_loader, model): | |||
| for i, batch in enumerate(test_loader): | |||
| batch = batch_constructor(config, batch) | |||
| with torch.no_grad(): | |||
| output, score = model(batch) | |||
| score = score.detach().cpu().numpy() | |||
| logit = model.logits.detach().cpu().numpy() | |||
| return score, logit | |||
def lime_main(config, trial_number=None):
    """Explain the model's prediction for test samples with LIME.

    Loads a checkpoint from ``config.output_path``, rebuilds ``FakeNewsModel``
    from it and, for each row of the test frame (currently truncated to the
    first row), writes under ``<output_path>/lime/``:

    * ``score/<i>.csv`` and ``logit/<i>.csv`` - the outputs returned by
      ``lime_`` for the unperturbed sample,
    * ``text/<i>.html`` - the LIME text explanation,
    * ``image/<i>.png`` - the LIME image explanation with segment boundaries.

    Args:
        config: Dataset configuration; this function reads ``output_path``,
            ``device``, ``classes`` and ``DatasetLoader`` and calls
            ``assign_hyperparameters`` on it.
        trial_number: When truthy, ``checkpoint_<trial_number>.pt`` is tried
            first; ``checkpoint.pt`` is the fallback and the default.
    """
    # Function-scope import: only the image predictor needs a raw DataLoader.
    from torch.utils.data import DataLoader

    _, test_df, _ = make_dfs(config, )
    test_df = test_df[:1]  # only the first test sample is explained

    # Every load is mapped onto the configured device so a checkpoint saved
    # on another GPU (or on GPU at all) can still be restored here.
    device = torch.device(config.device)
    default_path = str(config.output_path) + '/checkpoint.pt'
    if trial_number:
        trial_path = str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt'
        try:
            checkpoint = torch.load(trial_path, map_location=device)
        except Exception as error:
            # Best effort, as before: fall back to the default checkpoint,
            # but say why instead of failing over silently.
            print('Could not load', trial_path, '(' + repr(error) + '); using', default_path)
            checkpoint = torch.load(default_path, map_location=device)
    else:
        checkpoint = torch.load(default_path, map_location=device)

    # Stored hyper-parameters are optional; without them the config is used
    # unchanged.
    try:
        config.assign_hyperparameters(checkpoint['parameters'])
    except Exception as error:
        print('Stored hyperparameters were not applied:', repr(error))

    model = FakeNewsModel(config).to(config.device)
    # The checkpoint is either a dict wrapping the weights or the bare state
    # dict; a genuine load_state_dict error is no longer masked by a retry.
    try:
        state_dict = checkpoint['model_state_dict']
    except (KeyError, TypeError):
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    text_explainer = LimeTextExplainer(class_names=config.classes)
    image_explainer = LimeImageExplainer()

    for sub_directory in ('/lime/', '/lime/score/', '/lime/logit/', '/lime/text/', '/lime/image/'):
        make_directory(config.output_path + sub_directory)

    def make_text_predictor(row):
        """Return LIME's classifier_fn for perturbed versions of row's text."""
        def text_predict_proba(text):
            scores = []
            print(len(text))
            for perturbed_text in text:
                # Copy so the caller's row keeps its original text.
                perturbed_row = row.copy()
                perturbed_row['text'] = perturbed_text
                test_loader = build_loaders(config, perturbed_row, mode="lime")
                score, _ = lime_(config, test_loader, model)
                scores.append(score)
            # LIME indexes the result as [sample, class]; stack the
            # per-sample outputs into one 2-D array.
            return np.vstack(scores)
        return text_predict_proba

    def make_image_predictor(base_item):
        """Return LIME's classifier_fn for perturbed versions of the image."""
        def image_predict_proba(image):
            scores = []
            print(len(image))
            for perturbed_image in image:
                # Reuse the already tokenized sample and swap only the image;
                # a DataLoader itself does not support item assignment.
                item = dict(base_item)
                item['image'] = torch.as_tensor(
                    perturbed_image, dtype=base_item['image'].dtype
                ).reshape((3, 224, 224))
                # Default collation, as in build_loaders, so lime_ receives
                # the same batch structure. NOTE(review): assumes lime_ only
                # needs the collated item dict - confirm.
                test_loader = DataLoader([item], batch_size=1, shuffle=False)
                score, _ = lime_(config, test_loader, model)
                scores.append(score)
            return np.vstack(scores)
        return image_predict_proba

    for i, row in test_df.iterrows():
        test_loader = build_loaders(config, row, mode="lime")
        score, logit = lime_(config, test_loader, model)
        np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",")
        np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",")

        text_exp = text_explainer.explain_instance(row['text'], make_text_predictor(row), num_features=5)
        text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html')
        print('text', i, 'finished')

        data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0)
        # LIME needs a NumPy image. The reshape (rather than a permute) is
        # kept so the predictor's inverse reshape restores exactly the tensor
        # the dataset produced.
        # NOTE(review): for a true channel-first tensor this reshape is not a
        # real HWC view - confirm what the segmentation should see.
        lime_image = np.asarray(data_items['image'].reshape((224, 224, 3)))
        img_exp = image_explainer.explain_instance(lime_image, make_image_predictor(data_items),
                                                   top_labels=2, hide_color=0, num_samples=1)
        temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5,
                                                hide_rest=False)
        img_boundry = mark_boundaries(temp / 255.0, mask)
        plt.imshow(img_boundry)
        plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png')
        # Release the figure so repeated samples do not pile up in memory.
        plt.close()
        print('image', i, 'finished')
| @@ -1,5 +1,6 @@ | |||
| import argparse | |||
| from lime_main import lime_main | |||
| from optuna_main import optuna_main | |||
| from test_main import test_main | |||
| from torch_main import torch_main | |||
| @@ -7,6 +8,8 @@ from data.twitter.config import TwitterConfig | |||
| from data.weibo.config import WeiboConfig | |||
| from torch.utils.tensorboard import SummaryWriter | |||
| from utils import make_directory | |||
| if __name__ == '__main__': | |||
| import os | |||
| @@ -15,6 +18,7 @@ if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--data', type=str, required=True) | |||
| parser.add_argument('--use_optuna', type=int, required=False) | |||
| parser.add_argument('--use_lime', type=int, required=False) | |||
| parser.add_argument('--just_test', type=int, required=False) | |||
| parser.add_argument('--batch', type=int, required=False) | |||
| parser.add_argument('--epoch', type=int, required=False) | |||
| @@ -44,12 +48,7 @@ if __name__ == '__main__': | |||
| config.output_path = str(args.extra) | |||
| use_optuna = False | |||
| try: | |||
| os.mkdir(config.output_path) | |||
| except OSError: | |||
| print("Creation of the directory failed") | |||
| else: | |||
| print("Successfully created the directory") | |||
| make_directory(config.output_path) | |||
| config.writer = SummaryWriter(config.output_path) | |||
| @@ -57,6 +56,8 @@ if __name__ == '__main__': | |||
| optuna_main(config, args.use_optuna) | |||
| elif args.just_test: | |||
| test_main(config, args.just_test) | |||
| elif args.use_lime: | |||
| lime_main(config, args.use_lime) | |||
| else: | |||
| torch_main(config) | |||
| @@ -66,7 +66,6 @@ class FakeNewsModel(nn.Module): | |||
| self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||
| self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||
| self.classifier = Classifier(config) | |||
| self.temperature = config.temperature | |||
| class_weights = torch.FloatTensor(config.class_weights) | |||
| self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||
| @@ -82,36 +81,15 @@ class FakeNewsModel(nn.Module): | |||
| self.text_embeddings = self.text_projection(text_features) | |||
| self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||
| score = self.classifier(self.multimodal_embeddings) | |||
| probs, output = torch.max(score.data, dim=1) | |||
| return output, score | |||
| class FakeNewsEmbeddingModel(nn.Module): | |||
| def __init__( | |||
| self, config | |||
| ): | |||
| super().__init__() | |||
| self.config = config | |||
| self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||
| self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||
| self.classifier = Classifier(config) | |||
| self.temperature = config.temperature | |||
| class_weights = torch.FloatTensor(config.class_weights) | |||
| self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||
| def forward(self, batch): | |||
| self.image_embeddings = self.image_projection(batch['image_embedding']) | |||
| self.text_embeddings = self.text_projection(batch['text_embedding']) | |||
| self.logits = (self.text_embeddings @ self.image_embeddings.T) | |||
| self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||
| score = self.classifier(self.multimodal_embeddings) | |||
| probs, output = torch.max(score.data, dim=1) | |||
| return output, score | |||
| def calculate_loss(model, score, label): | |||
| similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings) | |||
| # fake = (label == 1).nonzero() | |||
| @@ -124,10 +102,10 @@ def calculate_loss(model, score, label): | |||
| def calculate_similarity_loss(config, image_embeddings, text_embeddings): | |||
| # Calculating the Loss | |||
| logits = (text_embeddings @ image_embeddings.T) / config.temperature | |||
| logits = (text_embeddings @ image_embeddings.T) | |||
| images_similarity = image_embeddings @ image_embeddings.T | |||
| texts_similarity = text_embeddings @ text_embeddings.T | |||
| targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1) | |||
| targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1) | |||
| texts_loss = cross_entropy(logits, targets, reduction='none') | |||
| images_loss = cross_entropy(logits.T, targets.T, reduction='none') | |||
| loss = (images_loss + texts_loss) / 2.0 | |||
| @@ -18,7 +18,7 @@ def objective(trial, config, train_loader, validation_loader): | |||
| def optuna_main(config, n_trials=100): | |||
| train_df, test_df, validation_df, _ = make_dfs(config) | |||
| train_df, test_df, validation_df = make_dfs(config) | |||
| train_loader = build_loaders(config, train_df, mode="train") | |||
| validation_loader = build_loaders(config, validation_df, mode="validation") | |||
| test_loader = build_loaders(config, test_df, mode="test") | |||
| @@ -47,6 +47,7 @@ def optuna_main(config, n_trials=100): | |||
| trial = study.best_trial | |||
| print(' Number: ', trial.number) | |||
| print(" Value: ", trial.value) | |||
| s += 'number: ' + str(trial.number) + '\n' | |||
| s += 'value: ' + str(trial.value) + '\n' | |||
| print(" Params: ") | |||
| @@ -106,6 +106,6 @@ def test(config, test_loader, trial_number=None): | |||
| def test_main(config, trial_number=None): | |||
| train_df, test_df, validation_df, _ = make_dfs(config, ) | |||
| train_df, test_df, validation_df = make_dfs(config, ) | |||
| test_loader = build_loaders(config, test_df, mode="test") | |||
| test(config, test_loader, trial_number) | |||
| @@ -5,7 +5,7 @@ from test_main import test | |||
| def torch_main(config): | |||
| train_df, test_df, validation_df, _ = make_dfs(config, ) | |||
| train_df, test_df, validation_df = make_dfs(config, ) | |||
| train_loader = build_loaders(config, train_df, mode="train") | |||
| validation_loader = build_loaders(config, validation_df, mode="validation") | |||
| test_loader = build_loaders(config, test_df, mode="test") | |||
| @@ -1,6 +1,6 @@ | |||
| import numpy as np | |||
| import torch | |||
| import os | |||
| class AvgMeter: | |||
| def __init__(self, name="Metric"): | |||
| @@ -91,3 +91,12 @@ class EarlyStopping: | |||
| # self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...') | |||
| # torch.save(model.state_dict(), self.path) | |||
| # self.val_loss_min = val_loss | |||
def make_directory(path):
    """Create directory ``path`` on a best-effort basis and report the outcome.

    Missing parent directories are created as well. An already existing
    directory is reported as such instead of as a failure. Any ``OSError``
    raised while creating the directory is printed, not propagated, so the
    caller keeps running.

    Args:
        path: Directory to create.
    """
    if os.path.isdir(path):
        print("Directory already exists")
        return
    try:
        # exist_ok covers another process creating the directory between
        # the check above and this call.
        os.makedirs(path, exist_ok=True)
    except OSError as error:
        print(f"Creation of the directory failed: {error}")
    else:
        print("Successfully created the directory")