size = 224 | size = 224 | ||||
num_projection_layers = 1 | num_projection_layers = 1 | ||||
projection_size = 256 | |||||
projection_size = 64 | |||||
dropout = 0.3 | dropout = 0.3 | ||||
hidden_size = 256 | |||||
hidden_size = 128 | |||||
num_region = 64 # 16 #TODO | num_region = 64 # 16 #TODO | ||||
region_size = projection_size // num_region | region_size = projection_size // num_region | ||||
import torch | import torch | ||||
from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
import albumentations as A | |||||
from transformers import ViTFeatureExtractor | |||||
from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||||
def get_transforms(config): | |||||
return A.Compose( | |||||
[ | |||||
A.Resize(config.size, config.size, always_apply=True), | |||||
A.Normalize(max_pixel_value=255.0, always_apply=True), | |||||
] | |||||
) | |||||
def get_tokenizer(config): | |||||
if 'bigbird' in config.text_encoder_model: | |||||
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||||
elif 'xlnet' in config.text_encoder_model: | |||||
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||||
else: | |||||
tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer) | |||||
return tokenizer | |||||
class DatasetLoader(Dataset): | class DatasetLoader(Dataset): | ||||
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||||
def __init__(self, config, dataframe, mode): | |||||
self.config = config | self.config = config | ||||
self.image_filenames = image_filenames | |||||
self.text = list(texts) | |||||
self.encoded_text = tokenizer( | |||||
list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt' | |||||
) | |||||
self.labels = labels | |||||
self.transforms = transforms | |||||
self.mode = mode | self.mode = mode | ||||
if mode == 'lime': | |||||
self.image_filenames = [dataframe["image"],] | |||||
self.text = [dataframe["text"],] | |||||
self.labels = [dataframe["label"],] | |||||
else: | |||||
self.image_filenames = dataframe["image"].values | |||||
self.text = list(dataframe["text"].values) | |||||
self.labels = dataframe["label"].values | |||||
tokenizer = get_tokenizer(config) | |||||
self.encoded_text = tokenizer(self.text, padding=True, truncation=True, max_length=config.max_length, return_tensors='pt') | |||||
if 'resnet' in config.image_model_name: | |||||
self.transforms = get_transforms(config) | |||||
else: | |||||
self.transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name) | |||||
def set_text(self, idx): | def set_text(self, idx): | ||||
item = { | item = { | ||||
item['id'] = idx | item['id'] = idx | ||||
return item | return item | ||||
def set_image(self, image, item): | |||||
def set_image(self, image): | |||||
if 'resnet' in self.config.image_model_name: | if 'resnet' in self.config.image_model_name: | ||||
image = self.transforms(image=image)['image'] | image = self.transforms(image=image)['image'] | ||||
item['image'] = torch.as_tensor(image).reshape((3, 224, 224)) | |||||
return {'image': torch.as_tensor(image).reshape((3, 224, 224))} | |||||
else: | else: | ||||
image = self.transforms(images=image, return_tensors='pt') | image = self.transforms(images=image, return_tensors='pt') | ||||
image = image.convert_to_tensors(tensor_type='pt')['pixel_values'] | image = image.convert_to_tensors(tensor_type='pt')['pixel_values'] | ||||
item['image'] = image.reshape((3, 224, 224)) | |||||
return {'image': image.reshape((3, 224, 224))} |
name = 'twitter' | name = 'twitter' | ||||
DatasetLoader = TwitterDatasetLoader | DatasetLoader = TwitterDatasetLoader | ||||
data_path = '/twitter/' | |||||
output_path = '' | |||||
data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/twitter/' | |||||
# data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/twitter/' | |||||
output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/' | |||||
# output_path = '' | |||||
train_image_path = data_path + 'images_train/' | train_image_path = data_path + 'images_train/' | ||||
validation_image_path = data_path + 'images_validation/' | |||||
validation_image_path = data_path + 'images_test/' | |||||
test_image_path = data_path + 'images_test/' | test_image_path = data_path + 'images_test/' | ||||
train_text_path = data_path + 'twitter_train_translated.csv' | train_text_path = data_path + 'twitter_train_translated.csv' | ||||
validation_text_path = data_path + 'twitter_validation_translated.csv' | |||||
validation_text_path = data_path + 'twitter_test_translated.csv' | |||||
test_text_path = data_path + 'twitter_test_translated.csv' | test_text_path = data_path + 'twitter_test_translated.csv' | ||||
batch_size = 128 | batch_size = 128 | ||||
attention_weight_decay = 0.001 | attention_weight_decay = 0.001 | ||||
classification_weight_decay = 0.001 | classification_weight_decay = 0.001 | ||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") | |||||
image_model_name = 'vit-base-patch16-224' | |||||
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||||
image_embedding = 768 | image_embedding = 768 | ||||
text_encoder_model = "bert-base-uncased" | |||||
text_tokenizer = "bert-base-uncased" | |||||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||||
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||||
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||||
text_embedding = 768 | text_embedding = 768 | ||||
max_length = 32 | max_length = 32 | ||||
class TwitterDatasetLoader(DatasetLoader): | class TwitterDatasetLoader(DatasetLoader): | ||||
def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
item = self.set_text(idx) | item = self.set_text(idx) | ||||
item.update(self.set_img(idx)) | |||||
return item | |||||
def set_img(self, idx): | |||||
try: | try: | ||||
if self.mode == 'train': | if self.mode == 'train': | ||||
image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}") | image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}") | ||||
except: | except: | ||||
image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB') | image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB') | ||||
image = asarray(image) | image = asarray(image) | ||||
self.set_image(image, item) | |||||
return item | |||||
return self.set_image(image) | |||||
def __len__(self): | def __len__(self): | ||||
return len(self.text) | return len(self.text) | ||||
class TwitterEmbeddingDatasetLoader(DatasetLoader): | |||||
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||||
super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||||
self.config = config | |||||
self.image_filenames = image_filenames | |||||
self.text = list(texts) | |||||
self.encoded_text = tokenizer | |||||
self.labels = labels | |||||
self.transforms = transforms | |||||
self.mode = mode | |||||
if mode == 'train': | |||||
with open(config.train_image_embedding_path, 'rb') as handle: | |||||
self.image_embedding = pickle.load(handle) | |||||
with open(config.train_text_embedding_path, 'rb') as handle: | |||||
self.text_embedding = pickle.load(handle) | |||||
else: | |||||
with open(config.validation_image_embedding_path, 'rb') as handle: | |||||
self.image_embedding = pickle.load(handle) | |||||
with open(config.validation_text_embedding_path, 'rb') as handle: | |||||
self.text_embedding = pickle.load(handle) | |||||
def __getitem__(self, idx): | |||||
item = dict() | |||||
item['id'] = idx | |||||
item['image_embedding'] = self.image_embedding[idx] | |||||
item['text_embedding'] = self.text_embedding[idx] | |||||
item['label'] = self.labels[idx] | |||||
return item | |||||
def __len__(self): | |||||
return len(self.text) |
name = 'weibo' | name = 'weibo' | ||||
DatasetLoader = WeiboDatasetLoader | DatasetLoader = WeiboDatasetLoader | ||||
data_path = 'weibo/' | |||||
output_path = '' | |||||
data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/weibo/' | |||||
# data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/weibo/' | |||||
output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/' | |||||
# output_path = '' | |||||
rumor_image_path = data_path + 'rumor_images/' | rumor_image_path = data_path + 'rumor_images/' | ||||
nonrumor_image_path = data_path + 'nonrumor_images/' | nonrumor_image_path = data_path + 'nonrumor_images/' | ||||
train_text_path = data_path + 'weibo_train.csv' | train_text_path = data_path + 'weibo_train.csv' | ||||
validation_text_path = data_path + 'weibo_validation.csv' | |||||
validation_text_path = data_path + 'weibo_train.csv' | |||||
test_text_path = data_path + 'weibo_test.csv' | test_text_path = data_path + 'weibo_test.csv' | ||||
batch_size = 128 | |||||
batch_size = 64 | |||||
epochs = 100 | epochs = 100 | ||||
num_workers = 2 | |||||
num_workers = 4 | |||||
head_lr = 1e-03 | head_lr = 1e-03 | ||||
image_encoder_lr = 1e-02 | image_encoder_lr = 1e-02 | ||||
text_encoder_lr = 1e-05 | text_encoder_lr = 1e-05 | ||||
weight_decay = 0.001 | weight_decay = 0.001 | ||||
classification_lr = 1e-02 | classification_lr = 1e-02 | ||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") | |||||
image_model_name = 'vit-base-patch16-224' # 'resnet101' | |||||
image_embedding = 768 # 2048 | |||||
num_img_region = 64 # TODO | |||||
text_encoder_model = "bert-base-chinese" | |||||
text_tokenizer = "bert-base-chinese" | |||||
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||||
image_embedding = 768 | |||||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||||
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||||
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||||
text_embedding = 768 | text_embedding = 768 | ||||
max_length = 200 | max_length = 200 | ||||
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | ||||
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | ||||
# self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||||
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | ||||
self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | ||||
# self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||||
# self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||||
self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||||
self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||||
class WeiboDatasetLoader(DatasetLoader): | class WeiboDatasetLoader(DatasetLoader): | ||||
def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
item = self.set_text(idx) | item = self.set_text(idx) | ||||
item.update(self.set_img(idx)) | |||||
return item | |||||
def set_img(self, idx): | |||||
if self.labels[idx] == 1: | if self.labels[idx] == 1: | ||||
try: | try: | ||||
image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}") | image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}") | ||||
except: | except: | ||||
image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | ||||
image = asarray(image) | image = asarray(image) | ||||
self.set_image(image, item) | |||||
return item | |||||
return self.set_image(image) | |||||
def __len__(self): | def __len__(self): | ||||
return len(self.text) | return len(self.text) | ||||
class WeiboEmbeddingDatasetLoader(DatasetLoader): | |||||
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode): | |||||
super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode) | |||||
self.config = config | |||||
self.image_filenames = image_filenames | |||||
self.text = list(texts) | |||||
self.encoded_text = tokenizer | |||||
self.labels = labels | |||||
self.transforms = transforms | |||||
self.mode = mode | |||||
if mode == 'train': | |||||
with open(config.train_image_embedding_path, 'rb') as handle: | |||||
self.image_embedding = pickle.load(handle) | |||||
with open(config.train_text_embedding_path, 'rb') as handle: | |||||
self.text_embedding = pickle.load(handle) | |||||
else: | |||||
with open(config.validation_image_embedding_path, 'rb') as handle: | |||||
self.image_embedding = pickle.load(handle) | |||||
with open(config.validation_text_embedding_path, 'rb') as handle: | |||||
self.text_embedding = pickle.load(handle) | |||||
def __getitem__(self, idx): | |||||
item = dict() | |||||
item['id'] = idx | |||||
item['image_embedding'] = self.image_embedding[idx] | |||||
item['text_embedding'] = self.text_embedding[idx] | |||||
item['label'] = self.labels[idx] | |||||
return item | |||||
def __len__(self): | |||||
return len(self.text) |
import albumentations as A | |||||
import numpy as np | import numpy as np | ||||
import pandas as pd | import pandas as pd | ||||
from sklearn.utils import compute_class_weight | from sklearn.utils import compute_class_weight | ||||
from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||||
def get_transforms(config, mode="train"): | |||||
if mode == "train": | |||||
return A.Compose( | |||||
[ | |||||
A.Resize(config.size, config.size, always_apply=True), | |||||
A.Normalize(max_pixel_value=255.0, always_apply=True), | |||||
] | |||||
) | |||||
else: | |||||
return A.Compose( | |||||
[ | |||||
A.Resize(config.size, config.size, always_apply=True), | |||||
A.Normalize(max_pixel_value=255.0, always_apply=True), | |||||
] | |||||
) | |||||
def make_dfs(config): | def make_dfs(config): | ||||
validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True) | validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True) | ||||
validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x)) | validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x)) | ||||
unlabeled_dataframe = None | |||||
if config.has_unlabeled_data: | |||||
unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path) | |||||
unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True) | |||||
return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe | |||||
def get_tokenizer(config): | |||||
if 'bigbird' in config.text_encoder_model: | |||||
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||||
elif 'xlnet' in config.text_encoder_model: | |||||
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||||
else: | |||||
tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer) | |||||
return tokenizer | |||||
return train_dataframe, test_dataframe, validation_dataframe | |||||
def build_loaders(config, dataframe, mode): | def build_loaders(config, dataframe, mode): | ||||
if 'resnet' in config.image_model_name: | |||||
transforms = get_transforms(config, mode=mode) | |||||
else: | |||||
transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name) | |||||
dataset = config.DatasetLoader( | |||||
config, | |||||
dataframe["image"].values, | |||||
dataframe["text"].values, | |||||
dataframe["label"].values, | |||||
tokenizer=get_tokenizer(config), | |||||
transforms=transforms, | |||||
mode=mode, | |||||
) | |||||
dataset = config.DatasetLoader(config, dataframe=dataframe, mode=mode) | |||||
if mode != 'train': | if mode != 'train': | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset, | dataset, | ||||
batch_size=config.batch_size, | |||||
num_workers=config.num_workers, | |||||
batch_size=config.batch_size // 2 if mode == 'lime' else 1, | |||||
num_workers=config.num_workers // 2, | |||||
pin_memory=False, | pin_memory=False, | ||||
shuffle=False, | shuffle=False, | ||||
) | ) |
import random | |||||
import numpy as np | |||||
import pandas as pd | |||||
import torch | |||||
from matplotlib import pyplot as plt | |||||
from tqdm import tqdm | |||||
from data_loaders import make_dfs, build_loaders | |||||
from learner import batch_constructor | |||||
from model import FakeNewsModel | |||||
from lime.lime_image import LimeImageExplainer | |||||
from skimage.segmentation import mark_boundaries | |||||
from lime.lime_text import LimeTextExplainer | |||||
from utils import make_directory | |||||
def lime_(config, test_loader, model): | |||||
for i, batch in enumerate(test_loader): | |||||
batch = batch_constructor(config, batch) | |||||
with torch.no_grad(): | |||||
output, score = model(batch) | |||||
score = score.detach().cpu().numpy() | |||||
logit = model.logits.detach().cpu().numpy() | |||||
return score, logit | |||||
def lime_main(config, trial_number=None): | |||||
_, test_df, _ = make_dfs(config, ) | |||||
test_df = test_df[:1] | |||||
if trial_number: | |||||
try: | |||||
checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt') | |||||
except: | |||||
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt') | |||||
else: | |||||
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device)) | |||||
try: | |||||
parameters = checkpoint['parameters'] | |||||
config.assign_hyperparameters(parameters) | |||||
except: | |||||
pass | |||||
model = FakeNewsModel(config).to(config.device) | |||||
try: | |||||
model.load_state_dict(checkpoint['model_state_dict']) | |||||
except: | |||||
model.load_state_dict(checkpoint) | |||||
model.eval() | |||||
torch.manual_seed(27) | |||||
random.seed(27) | |||||
np.random.seed(27) | |||||
text_explainer = LimeTextExplainer(class_names=config.classes) | |||||
image_explainer = LimeImageExplainer() | |||||
make_directory(config.output_path + '/lime/') | |||||
make_directory(config.output_path + '/lime/score/') | |||||
make_directory(config.output_path + '/lime/logit/') | |||||
make_directory(config.output_path + '/lime/text/') | |||||
make_directory(config.output_path + '/lime/image/') | |||||
def text_predict_proba(text): | |||||
scores = [] | |||||
print(len(text)) | |||||
for i in text: | |||||
row['text'] = i | |||||
test_loader = build_loaders(config, row, mode="lime") | |||||
score, _ = lime_(config, test_loader, model) | |||||
scores.append(score) | |||||
return np.array(scores) | |||||
def image_predict_proba(image): | |||||
scores = [] | |||||
print(len(image)) | |||||
for i in image: | |||||
test_loader = build_loaders(config, row, mode="lime") | |||||
test_loader['image'] = i.reshape((3, 224, 224)) | |||||
score, _ = lime_(config, test_loader, model) | |||||
scores.append(score) | |||||
return np.array(scores) | |||||
for i, row in test_df.iterrows(): | |||||
test_loader = build_loaders(config, row, mode="lime") | |||||
score, logit = lime_(config, test_loader, model) | |||||
np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",") | |||||
np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",") | |||||
text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5) | |||||
text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html') | |||||
print('text', i, 'finished') | |||||
data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0) | |||||
img_exp = image_explainer.explain_instance(data_items['image'].reshape((224, 224, 3)), image_predict_proba, | |||||
top_labels=2, hide_color=0, num_samples=1) | |||||
temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5, | |||||
hide_rest=False) | |||||
img_boundry = mark_boundaries(temp / 255.0, mask) | |||||
plt.imshow(img_boundry) | |||||
plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png') | |||||
print('image', i, 'finished') |
import argparse | import argparse | ||||
from lime_main import lime_main | |||||
from optuna_main import optuna_main | from optuna_main import optuna_main | ||||
from test_main import test_main | from test_main import test_main | ||||
from torch_main import torch_main | from torch_main import torch_main | ||||
from data.weibo.config import WeiboConfig | from data.weibo.config import WeiboConfig | ||||
from torch.utils.tensorboard import SummaryWriter | from torch.utils.tensorboard import SummaryWriter | ||||
from utils import make_directory | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
import os | import os | ||||
parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
parser.add_argument('--data', type=str, required=True) | parser.add_argument('--data', type=str, required=True) | ||||
parser.add_argument('--use_optuna', type=int, required=False) | parser.add_argument('--use_optuna', type=int, required=False) | ||||
parser.add_argument('--use_lime', type=int, required=False) | |||||
parser.add_argument('--just_test', type=int, required=False) | parser.add_argument('--just_test', type=int, required=False) | ||||
parser.add_argument('--batch', type=int, required=False) | parser.add_argument('--batch', type=int, required=False) | ||||
parser.add_argument('--epoch', type=int, required=False) | parser.add_argument('--epoch', type=int, required=False) | ||||
config.output_path = str(args.extra) | config.output_path = str(args.extra) | ||||
use_optuna = False | use_optuna = False | ||||
try: | |||||
os.mkdir(config.output_path) | |||||
except OSError: | |||||
print("Creation of the directory failed") | |||||
else: | |||||
print("Successfully created the directory") | |||||
make_directory(config.output_path) | |||||
config.writer = SummaryWriter(config.output_path) | config.writer = SummaryWriter(config.output_path) | ||||
optuna_main(config, args.use_optuna) | optuna_main(config, args.use_optuna) | ||||
elif args.just_test: | elif args.just_test: | ||||
test_main(config, args.just_test) | test_main(config, args.just_test) | ||||
elif args.use_lime: | |||||
lime_main(config, args.use_lime) | |||||
else: | else: | ||||
torch_main(config) | torch_main(config) | ||||
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | ||||
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | ||||
self.classifier = Classifier(config) | self.classifier = Classifier(config) | ||||
self.temperature = config.temperature | |||||
class_weights = torch.FloatTensor(config.class_weights) | class_weights = torch.FloatTensor(config.class_weights) | ||||
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | ||||
self.text_embeddings = self.text_projection(text_features) | self.text_embeddings = self.text_projection(text_features) | ||||
self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | ||||
score = self.classifier(self.multimodal_embeddings) | |||||
probs, output = torch.max(score.data, dim=1) | |||||
return output, score | |||||
class FakeNewsEmbeddingModel(nn.Module): | |||||
def __init__( | |||||
self, config | |||||
): | |||||
super().__init__() | |||||
self.config = config | |||||
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||||
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||||
self.classifier = Classifier(config) | |||||
self.temperature = config.temperature | |||||
class_weights = torch.FloatTensor(config.class_weights) | |||||
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||||
def forward(self, batch): | |||||
self.image_embeddings = self.image_projection(batch['image_embedding']) | |||||
self.text_embeddings = self.text_projection(batch['text_embedding']) | |||||
self.logits = (self.text_embeddings @ self.image_embeddings.T) | |||||
self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||||
score = self.classifier(self.multimodal_embeddings) | score = self.classifier(self.multimodal_embeddings) | ||||
probs, output = torch.max(score.data, dim=1) | probs, output = torch.max(score.data, dim=1) | ||||
return output, score | return output, score | ||||
def calculate_loss(model, score, label): | def calculate_loss(model, score, label): | ||||
similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings) | similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings) | ||||
# fake = (label == 1).nonzero() | # fake = (label == 1).nonzero() | ||||
def calculate_similarity_loss(config, image_embeddings, text_embeddings): | def calculate_similarity_loss(config, image_embeddings, text_embeddings): | ||||
# Calculating the Loss | # Calculating the Loss | ||||
logits = (text_embeddings @ image_embeddings.T) / config.temperature | |||||
logits = (text_embeddings @ image_embeddings.T) | |||||
images_similarity = image_embeddings @ image_embeddings.T | images_similarity = image_embeddings @ image_embeddings.T | ||||
texts_similarity = text_embeddings @ text_embeddings.T | texts_similarity = text_embeddings @ text_embeddings.T | ||||
targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1) | |||||
targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1) | |||||
texts_loss = cross_entropy(logits, targets, reduction='none') | texts_loss = cross_entropy(logits, targets, reduction='none') | ||||
images_loss = cross_entropy(logits.T, targets.T, reduction='none') | images_loss = cross_entropy(logits.T, targets.T, reduction='none') | ||||
loss = (images_loss + texts_loss) / 2.0 | loss = (images_loss + texts_loss) / 2.0 |
def optuna_main(config, n_trials=100): | def optuna_main(config, n_trials=100): | ||||
train_df, test_df, validation_df, _ = make_dfs(config) | |||||
train_df, test_df, validation_df = make_dfs(config) | |||||
train_loader = build_loaders(config, train_df, mode="train") | train_loader = build_loaders(config, train_df, mode="train") | ||||
validation_loader = build_loaders(config, validation_df, mode="validation") | validation_loader = build_loaders(config, validation_df, mode="validation") | ||||
test_loader = build_loaders(config, test_df, mode="test") | test_loader = build_loaders(config, test_df, mode="test") | ||||
trial = study.best_trial | trial = study.best_trial | ||||
print(' Number: ', trial.number) | print(' Number: ', trial.number) | ||||
print(" Value: ", trial.value) | print(" Value: ", trial.value) | ||||
s += 'number: ' + str(trial.number) + '\n' | |||||
s += 'value: ' + str(trial.value) + '\n' | s += 'value: ' + str(trial.value) + '\n' | ||||
print(" Params: ") | print(" Params: ") |
def test_main(config, trial_number=None): | def test_main(config, trial_number=None): | ||||
train_df, test_df, validation_df, _ = make_dfs(config, ) | |||||
train_df, test_df, validation_df = make_dfs(config, ) | |||||
test_loader = build_loaders(config, test_df, mode="test") | test_loader = build_loaders(config, test_df, mode="test") | ||||
test(config, test_loader, trial_number) | test(config, test_loader, trial_number) |
def torch_main(config): | def torch_main(config): | ||||
train_df, test_df, validation_df, _ = make_dfs(config, ) | |||||
train_df, test_df, validation_df = make_dfs(config, ) | |||||
train_loader = build_loaders(config, train_df, mode="train") | train_loader = build_loaders(config, train_df, mode="train") | ||||
validation_loader = build_loaders(config, validation_df, mode="validation") | validation_loader = build_loaders(config, validation_df, mode="validation") | ||||
test_loader = build_loaders(config, test_df, mode="test") | test_loader = build_loaders(config, test_df, mode="test") |
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import os | |||||
class AvgMeter: | class AvgMeter: | ||||
def __init__(self, name="Metric"): | def __init__(self, name="Metric"): | ||||
# self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...') | # self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...') | ||||
# torch.save(model.state_dict(), self.path) | # torch.save(model.state_dict(), self.path) | ||||
# self.val_loss_min = val_loss | # self.val_loss_min = val_loss | ||||
def make_directory(path): | |||||
try: | |||||
os.mkdir(path) | |||||
except OSError: | |||||
print("Creation of the directory failed") | |||||
else: | |||||
print("Successfully created the directory") |