@@ -59,9 +59,9 @@ class Config: | |||
size = 224 | |||
num_projection_layers = 1 | |||
projection_size = 256 | |||
projection_size = 64 | |||
dropout = 0.3 | |||
hidden_size = 256 | |||
hidden_size = 128 | |||
num_region = 64 # 16 #TODO | |||
region_size = projection_size // num_region | |||
@@ -1,17 +1,48 @@ | |||
import torch | |||
from torch.utils.data import Dataset | |||
import albumentations as A | |||
from transformers import ViTFeatureExtractor | |||
from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||
def get_transforms(config):
    """Build the albumentations pipeline applied to raw images.

    Each image is resized to ``config.size`` x ``config.size`` and then
    normalized with 255.0 as the maximum pixel value.
    """
    side = config.size
    steps = [
        A.Resize(side, side, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)
def get_tokenizer(config):
    """Load the pretrained tokenizer that matches the configured text encoder.

    The tokenizer class is picked from ``config.text_encoder_model``
    ('bigbird' is checked first, then 'xlnet', otherwise BERT) and its
    weights are loaded from ``config.text_tokenizer``.
    """
    encoder_name = config.text_encoder_model
    if 'bigbird' in encoder_name:
        tokenizer_cls = BigBirdTokenizer
    elif 'xlnet' in encoder_name:
        tokenizer_cls = XLNetTokenizer
    else:
        tokenizer_cls = BertTokenizer
    return tokenizer_cls.from_pretrained(config.text_tokenizer)
class DatasetLoader(Dataset): | |||
def __init__(self, config, dataframe, mode):
    """Store the samples of one split and prepare text/image preprocessing.

    Args:
        config: Experiment configuration. ``text_encoder_model``,
            ``text_tokenizer``, ``max_length`` and ``image_model_name`` are
            read here.
        dataframe: Source of the "image", "text" and "label" fields. In
            'lime' mode it is a single record whose three values are wrapped
            into one-element lists; otherwise the columns are read through
            ``.values``.
        mode: Name of the split/usage; 'lime' selects the single-record path.
    """
    self.config = config
    self.mode = mode

    if mode == 'lime':
        # A single record: wrap each field so indexing with 0 works.
        self.image_filenames = [dataframe["image"]]
        self.text = [dataframe["text"]]
        self.labels = [dataframe["label"]]
    else:
        self.image_filenames = dataframe["image"].values
        self.text = list(dataframe["text"].values)
        self.labels = dataframe["label"].values

    # All texts are tokenized once, up front, padded to a common length.
    tokenizer = get_tokenizer(config)
    self.encoded_text = tokenizer(
        self.text,
        padding=True,
        truncation=True,
        max_length=config.max_length,
        return_tensors='pt',
    )

    if 'resnet' in config.image_model_name:
        self.transforms = get_transforms(config)
    else:
        self.transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)
def set_text(self, idx): | |||
item = { | |||
@@ -23,11 +54,11 @@ class DatasetLoader(Dataset): | |||
item['id'] = idx | |||
return item | |||
def set_image(self, image, item): | |||
def set_image(self, image): | |||
if 'resnet' in self.config.image_model_name: | |||
image = self.transforms(image=image)['image'] | |||
item['image'] = torch.as_tensor(image).reshape((3, 224, 224)) | |||
return {'image': torch.as_tensor(image).reshape((3, 224, 224))} | |||
else: | |||
image = self.transforms(images=image, return_tensors='pt') | |||
image = image.convert_to_tensors(tensor_type='pt')['pixel_values'] | |||
item['image'] = image.reshape((3, 224, 224)) | |||
return {'image': image.reshape((3, 224, 224))} |
@@ -8,15 +8,17 @@ class TwitterConfig(Config): | |||
name = 'twitter' | |||
DatasetLoader = TwitterDatasetLoader | |||
data_path = '/twitter/' | |||
output_path = '' | |||
data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/twitter/' | |||
# data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/twitter/' | |||
output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/' | |||
# output_path = '' | |||
train_image_path = data_path + 'images_train/' | |||
validation_image_path = data_path + 'images_validation/' | |||
validation_image_path = data_path + 'images_test/' | |||
test_image_path = data_path + 'images_test/' | |||
train_text_path = data_path + 'twitter_train_translated.csv' | |||
validation_text_path = data_path + 'twitter_validation_translated.csv' | |||
validation_text_path = data_path + 'twitter_test_translated.csv' | |||
test_text_path = data_path + 'twitter_test_translated.csv' | |||
batch_size = 128 | |||
@@ -32,12 +34,14 @@ class TwitterConfig(Config): | |||
attention_weight_decay = 0.001 | |||
classification_weight_decay = 0.001 | |||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") | |||
image_model_name = 'vit-base-patch16-224' | |||
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||
image_embedding = 768 | |||
text_encoder_model = "bert-base-uncased" | |||
text_tokenizer = "bert-base-uncased" | |||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
text_embedding = 768 | |||
max_length = 32 | |||
@@ -10,7 +10,10 @@ from data.data_loader import DatasetLoader | |||
class TwitterDatasetLoader(DatasetLoader): | |||
def __getitem__(self, idx): | |||
item = self.set_text(idx) | |||
item.update(self.set_img(idx)) | |||
return item | |||
def set_img(self, idx): | |||
try: | |||
if self.mode == 'train': | |||
image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}") | |||
@@ -22,44 +25,8 @@ class TwitterDatasetLoader(DatasetLoader): | |||
except: | |||
image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
image = asarray(image) | |||
self.set_image(image, item) | |||
return item | |||
return self.set_image(image) | |||
def __len__(self):
    """Return the number of samples, taken as the number of stored texts."""
    return len(self.text)
class TwitterEmbeddingDatasetLoader(DatasetLoader):
    """Dataset that serves precomputed image and text embeddings.

    The embeddings are unpickled from the paths named in the configuration:
    the train paths when ``mode == 'train'``, the validation paths otherwise.
    """

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        # Stores the object passed as ``tokenizer`` itself, replacing
        # whatever the base initializer put in ``encoded_text``.
        self.encoded_text = tokenizer
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

        split = 'train' if mode == 'train' else 'validation'
        # Image embeddings are loaded first, then text embeddings.
        for kind in ('image', 'text'):
            with open(getattr(config, f'{split}_{kind}_embedding_path'), 'rb') as handle:
                setattr(self, f'{kind}_embedding', pickle.load(handle))

    def __getitem__(self, idx):
        """Return the id, both embeddings and the label of sample ``idx``."""
        return {
            'id': idx,
            'image_embedding': self.image_embedding[idx],
            'text_embedding': self.text_embedding[idx],
            'label': self.labels[idx],
        }

    def __len__(self):
        """Return the number of samples, taken as the number of stored texts."""
        return len(self.text)
@@ -9,32 +9,35 @@ class WeiboConfig(Config): | |||
name = 'weibo' | |||
DatasetLoader = WeiboDatasetLoader | |||
data_path = 'weibo/' | |||
output_path = '' | |||
data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/weibo/' | |||
# data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/weibo/' | |||
output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/' | |||
# output_path = '' | |||
rumor_image_path = data_path + 'rumor_images/' | |||
nonrumor_image_path = data_path + 'nonrumor_images/' | |||
train_text_path = data_path + 'weibo_train.csv' | |||
validation_text_path = data_path + 'weibo_validation.csv' | |||
validation_text_path = data_path + 'weibo_train.csv' | |||
test_text_path = data_path + 'weibo_test.csv' | |||
batch_size = 128 | |||
batch_size = 64 | |||
epochs = 100 | |||
num_workers = 2 | |||
num_workers = 4 | |||
head_lr = 1e-03 | |||
image_encoder_lr = 1e-02 | |||
text_encoder_lr = 1e-05 | |||
weight_decay = 0.001 | |||
classification_lr = 1e-02 | |||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") | |||
image_model_name = 'vit-base-patch16-224' # 'resnet101' | |||
image_embedding = 768 # 2048 | |||
num_img_region = 64 # TODO | |||
text_encoder_model = "bert-base-chinese" | |||
text_tokenizer = "bert-base-chinese" | |||
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224' | |||
image_embedding = 768 | |||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased" | |||
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
text_embedding = 768 | |||
max_length = 200 | |||
@@ -53,10 +56,9 @@ class WeiboConfig(Config): | |||
self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) | |||
self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) | |||
# self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1) | |||
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) | |||
self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) | |||
# self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
# self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ]) | |||
self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ]) | |||
@@ -11,7 +11,10 @@ from data.data_loader import DatasetLoader | |||
class WeiboDatasetLoader(DatasetLoader): | |||
def __getitem__(self, idx): | |||
item = self.set_text(idx) | |||
item.update(self.set_img(idx)) | |||
return item | |||
def set_img(self, idx): | |||
if self.labels[idx] == 1: | |||
try: | |||
image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}") | |||
@@ -24,43 +27,8 @@ class WeiboDatasetLoader(DatasetLoader): | |||
except: | |||
image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB') | |||
image = asarray(image) | |||
self.set_image(image, item) | |||
return item | |||
return self.set_image(image) | |||
def __len__(self):
    """Return the number of samples, taken as the number of stored texts."""
    return len(self.text)
class WeiboEmbeddingDatasetLoader(DatasetLoader):
    """Dataset that serves precomputed image and text embeddings.

    The embeddings are unpickled from the paths named in the configuration:
    the train paths when ``mode == 'train'``, the validation paths otherwise.
    """

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        # Stores the object passed as ``tokenizer`` itself, replacing
        # whatever the base initializer put in ``encoded_text``.
        self.encoded_text = tokenizer
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

        split = 'train' if mode == 'train' else 'validation'
        # Image embeddings are loaded first, then text embeddings.
        for kind in ('image', 'text'):
            with open(getattr(config, f'{split}_{kind}_embedding_path'), 'rb') as handle:
                setattr(self, f'{kind}_embedding', pickle.load(handle))

    def __getitem__(self, idx):
        """Return the id, both embeddings and the label of sample ``idx``."""
        return {
            'id': idx,
            'image_embedding': self.image_embedding[idx],
            'text_embedding': self.text_embedding[idx],
            'label': self.labels[idx],
        }

    def __len__(self):
        """Return the number of samples, taken as the number of stored texts."""
        return len(self.text)
@@ -1,27 +1,8 @@ | |||
import albumentations as A | |||
import numpy as np | |||
import pandas as pd | |||
from sklearn.utils import compute_class_weight | |||
from torch.utils.data import DataLoader | |||
from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||
def get_transforms(config, mode="train"):
    """Build the albumentations pipeline applied to raw images.

    Args:
        config: Configuration providing ``size``, the target side length.
        mode: Kept for call compatibility. Every mode receives the same
            pipeline, so the value does not influence the result.

    Returns:
        A composed transform that resizes to ``config.size`` x
        ``config.size`` and normalizes with 255.0 as the maximum pixel value.
    """
    # The train and non-train branches were identical, so a single pipeline
    # definition replaces the duplicated if/else.
    return A.Compose(
        [
            A.Resize(config.size, config.size, always_apply=True),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ]
    )
def make_dfs(config): | |||
@@ -51,44 +32,16 @@ def make_dfs(config): | |||
validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True) | |||
validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x)) | |||
unlabeled_dataframe = None | |||
if config.has_unlabeled_data: | |||
unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path) | |||
unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True) | |||
return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe | |||
def get_tokenizer(config):
    """Load the pretrained tokenizer that matches the configured text encoder.

    'bigbird' is checked before 'xlnet' in ``config.text_encoder_model``;
    when neither occurs a BERT tokenizer is used. Weights always come from
    ``config.text_tokenizer``.
    """
    candidates = (('bigbird', BigBirdTokenizer), ('xlnet', XLNetTokenizer))
    for marker, tokenizer_cls in candidates:
        if marker in config.text_encoder_model:
            return tokenizer_cls.from_pretrained(config.text_tokenizer)
    return BertTokenizer.from_pretrained(config.text_tokenizer)
return train_dataframe, test_dataframe, validation_dataframe | |||
def build_loaders(config, dataframe, mode): | |||
if 'resnet' in config.image_model_name: | |||
transforms = get_transforms(config, mode=mode) | |||
else: | |||
transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name) | |||
dataset = config.DatasetLoader( | |||
config, | |||
dataframe["image"].values, | |||
dataframe["text"].values, | |||
dataframe["label"].values, | |||
tokenizer=get_tokenizer(config), | |||
transforms=transforms, | |||
mode=mode, | |||
) | |||
dataset = config.DatasetLoader(config, dataframe=dataframe, mode=mode) | |||
if mode != 'train': | |||
dataloader = DataLoader( | |||
dataset, | |||
batch_size=config.batch_size, | |||
num_workers=config.num_workers, | |||
batch_size=config.batch_size // 2 if mode == 'lime' else 1, | |||
num_workers=config.num_workers // 2, | |||
pin_memory=False, | |||
shuffle=False, | |||
) |
@@ -0,0 +1,110 @@ | |||
import random | |||
import numpy as np | |||
import pandas as pd | |||
import torch | |||
from matplotlib import pyplot as plt | |||
from tqdm import tqdm | |||
from data_loaders import make_dfs, build_loaders | |||
from learner import batch_constructor | |||
from model import FakeNewsModel | |||
from lime.lime_image import LimeImageExplainer | |||
from skimage.segmentation import mark_boundaries | |||
from lime.lime_text import LimeTextExplainer | |||
from utils import make_directory | |||
def lime_(config, test_loader, model):
    """Run the model over ``test_loader`` and return its score and logits.

    Args:
        config: Experiment configuration, forwarded to ``batch_constructor``.
        test_loader: Iterable of batches. The LIME callers pass a loader over
            a single sample, so exactly one batch is expected.
        model: Model called as ``model(batch)`` returning ``(output, score)``
            and exposing ``logits`` afterwards.

    Returns:
        Tuple ``(score, logit)`` of NumPy arrays taken from the last batch
        processed.

    Raises:
        ValueError: If ``test_loader`` yields no batch.
    """
    score = logit = None
    for batch in test_loader:
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            _, score = model(batch)
        score = score.detach().cpu().numpy()
        logit = model.logits.detach().cpu().numpy()

    if score is None:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError("test_loader yielded no batches")
    return score, logit
def lime_main(config, trial_number=None):
    """Produce LIME text and image explanations for test samples.

    A checkpoint is loaded from ``config.output_path``, the model is rebuilt
    from it, and for every selected test row the score and logit are written
    as CSV, the text explanation as HTML and the image explanation as PNG
    below ``config.output_path + '/lime/'``.

    Args:
        config: Experiment configuration (paths, device, classes and the
            dataset class to use).
        trial_number: When truthy, ``checkpoint_<trial_number>.pt`` is tried
            first and ``checkpoint.pt`` is the fallback; otherwise
            ``checkpoint.pt`` is loaded directly.
    """
    _, test_df, _ = make_dfs(config, )
    # Only the first test row is explained.
    test_df = test_df[:1]

    # Always map tensors onto the configured device so a checkpoint saved on
    # another device loads the same way on both code paths.
    map_location = torch.device(config.device)
    default_checkpoint = str(config.output_path) + '/checkpoint.pt'
    if trial_number:
        trial_checkpoint = str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt'
        try:
            checkpoint = torch.load(trial_checkpoint, map_location=map_location)
        except FileNotFoundError:
            checkpoint = torch.load(default_checkpoint, map_location=map_location)
    else:
        checkpoint = torch.load(default_checkpoint, map_location=map_location)

    # A checkpoint may be a bare state dict without stored hyperparameters.
    try:
        parameters = checkpoint['parameters']
    except (KeyError, TypeError):
        parameters = None
    if parameters is not None:
        config.assign_hyperparameters(parameters)

    model = FakeNewsModel(config).to(config.device)
    try:
        state_dict = checkpoint['model_state_dict']
    except (KeyError, TypeError):
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    text_explainer = LimeTextExplainer(class_names=config.classes)
    image_explainer = LimeImageExplainer()

    make_directory(config.output_path + '/lime/')
    make_directory(config.output_path + '/lime/score/')
    make_directory(config.output_path + '/lime/logit/')
    make_directory(config.output_path + '/lime/text/')
    make_directory(config.output_path + '/lime/image/')

    def score_row(sample_row, image=None):
        """Score one row, optionally replacing its image with ``image``."""
        loader = build_loaders(config, sample_row, mode="lime")
        if image is not None:
            patched = []
            for batch in loader:
                # NOTE(review): assumes batches are dicts that carry the
                # image under 'image', like the dataset items - confirm
                # against the collation used by build_loaders.
                batch['image'] = torch.as_tensor(image).reshape(batch['image'].shape)
                patched.append(batch)
            loader = patched
        score, _ = lime_(config, loader, model)
        return score

    # Both predict functions read ``row``, the loop variable of the
    # explanation loop below, at call time.
    def text_predict_proba(text):
        print(len(text))
        scores = []
        for perturbed_text in text:
            # Work on a copy so the row used for the image explanation keeps
            # its original text.
            sample_row = row.copy()
            sample_row['text'] = perturbed_text
            scores.append(score_row(sample_row))
        # One (1, classes) score per sample -> (samples, classes), the shape
        # LIME expects from a classifier function.
        # NOTE(review): LIME expects probabilities; confirm the model score
        # is normalized.
        return np.concatenate(scores, axis=0)

    def image_predict_proba(image):
        print(len(image))
        scores = [score_row(row, image=perturbed_image) for perturbed_image in image]
        return np.concatenate(scores, axis=0)

    for i, row in test_df.iterrows():
        test_loader = build_loaders(config, row, mode="lime")
        score, logit = lime_(config, test_loader, model)
        np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",")
        np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",")

        text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5)
        text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html')
        print('text', i, 'finished')

        data_items = config.DatasetLoader(config, dataframe=row, mode='lime')[0]
        # The LIME image explainer works on NumPy arrays, so hand it an array
        # rather than a tensor. image_predict_proba undoes this reshape.
        image_array = data_items['image'].reshape((224, 224, 3)).detach().cpu().numpy()
        img_exp = image_explainer.explain_instance(image_array, image_predict_proba,
                                                   top_labels=2, hide_color=0, num_samples=1)
        temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5,
                                                hide_rest=False)
        img_boundry = mark_boundaries(temp / 255.0, mask)
        plt.imshow(img_boundry)
        plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png')
        print('image', i, 'finished')
@@ -1,5 +1,6 @@ | |||
import argparse | |||
from lime_main import lime_main | |||
from optuna_main import optuna_main | |||
from test_main import test_main | |||
from torch_main import torch_main | |||
@@ -7,6 +8,8 @@ from data.twitter.config import TwitterConfig | |||
from data.weibo.config import WeiboConfig | |||
from torch.utils.tensorboard import SummaryWriter | |||
from utils import make_directory | |||
if __name__ == '__main__': | |||
import os | |||
@@ -15,6 +18,7 @@ if __name__ == '__main__': | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--data', type=str, required=True) | |||
parser.add_argument('--use_optuna', type=int, required=False) | |||
parser.add_argument('--use_lime', type=int, required=False) | |||
parser.add_argument('--just_test', type=int, required=False) | |||
parser.add_argument('--batch', type=int, required=False) | |||
parser.add_argument('--epoch', type=int, required=False) | |||
@@ -44,12 +48,7 @@ if __name__ == '__main__': | |||
config.output_path = str(args.extra) | |||
use_optuna = False | |||
try: | |||
os.mkdir(config.output_path) | |||
except OSError: | |||
print("Creation of the directory failed") | |||
else: | |||
print("Successfully created the directory") | |||
make_directory(config.output_path) | |||
config.writer = SummaryWriter(config.output_path) | |||
@@ -57,6 +56,8 @@ if __name__ == '__main__': | |||
optuna_main(config, args.use_optuna) | |||
elif args.just_test: | |||
test_main(config, args.just_test) | |||
elif args.use_lime: | |||
lime_main(config, args.use_lime) | |||
else: | |||
torch_main(config) | |||
@@ -66,7 +66,6 @@ class FakeNewsModel(nn.Module): | |||
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) | |||
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) | |||
self.classifier = Classifier(config) | |||
self.temperature = config.temperature | |||
class_weights = torch.FloatTensor(config.class_weights) | |||
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') | |||
@@ -82,36 +81,15 @@ class FakeNewsModel(nn.Module): | |||
self.text_embeddings = self.text_projection(text_features) | |||
self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) | |||
score = self.classifier(self.multimodal_embeddings) | |||
probs, output = torch.max(score.data, dim=1) | |||
return output, score | |||
class FakeNewsEmbeddingModel(nn.Module):
    """Classifier that works on precomputed image and text embeddings.

    Both embeddings are projected, concatenated and classified. The projected
    embeddings, their text-by-image product and the concatenation are kept as
    attributes after each forward pass.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
        self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
        self.classifier = Classifier(config)
        self.temperature = config.temperature

        weights = torch.FloatTensor(config.class_weights)
        self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=weights, reduction='mean')

    def forward(self, batch):
        """Return ``(output, score)``: predicted class indices and class scores."""
        self.image_embeddings = self.image_projection(batch['image_embedding'])
        self.text_embeddings = self.text_projection(batch['text_embedding'])

        self.logits = self.text_embeddings @ self.image_embeddings.T
        self.multimodal_embeddings = torch.cat(
            (self.image_embeddings, self.text_embeddings), dim=1
        )

        score = self.classifier(self.multimodal_embeddings)
        # Index of the highest score per row is the predicted class.
        output = torch.max(score.data, dim=1).indices
        return output, score
def calculate_loss(model, score, label): | |||
similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings) | |||
# fake = (label == 1).nonzero() | |||
@@ -124,10 +102,10 @@ def calculate_loss(model, score, label): | |||
def calculate_similarity_loss(config, image_embeddings, text_embeddings): | |||
# Calculating the Loss | |||
logits = (text_embeddings @ image_embeddings.T) / config.temperature | |||
logits = (text_embeddings @ image_embeddings.T) | |||
images_similarity = image_embeddings @ image_embeddings.T | |||
texts_similarity = text_embeddings @ text_embeddings.T | |||
targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1) | |||
targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1) | |||
texts_loss = cross_entropy(logits, targets, reduction='none') | |||
images_loss = cross_entropy(logits.T, targets.T, reduction='none') | |||
loss = (images_loss + texts_loss) / 2.0 |
@@ -18,7 +18,7 @@ def objective(trial, config, train_loader, validation_loader): | |||
def optuna_main(config, n_trials=100): | |||
train_df, test_df, validation_df, _ = make_dfs(config) | |||
train_df, test_df, validation_df = make_dfs(config) | |||
train_loader = build_loaders(config, train_df, mode="train") | |||
validation_loader = build_loaders(config, validation_df, mode="validation") | |||
test_loader = build_loaders(config, test_df, mode="test") | |||
@@ -47,6 +47,7 @@ def optuna_main(config, n_trials=100): | |||
trial = study.best_trial | |||
print(' Number: ', trial.number) | |||
print(" Value: ", trial.value) | |||
s += 'number: ' + str(trial.number) + '\n' | |||
s += 'value: ' + str(trial.value) + '\n' | |||
print(" Params: ") |
@@ -106,6 +106,6 @@ def test(config, test_loader, trial_number=None): | |||
def test_main(config, trial_number=None):
    """Build the test loader and run ``test`` on it.

    Args:
        config: Experiment configuration passed on to the data helpers and
            to ``test``.
        trial_number: Optional trial id forwarded to ``test``.
    """
    # make_dfs returns (train, test, validation); only the test frame is used.
    _, test_df, _ = make_dfs(config)
    test_loader = build_loaders(config, test_df, mode="test")
    test(config, test_loader, trial_number)
@@ -5,7 +5,7 @@ from test_main import test | |||
def torch_main(config): | |||
train_df, test_df, validation_df, _ = make_dfs(config, ) | |||
train_df, test_df, validation_df = make_dfs(config, ) | |||
train_loader = build_loaders(config, train_df, mode="train") | |||
validation_loader = build_loaders(config, validation_df, mode="validation") | |||
test_loader = build_loaders(config, test_df, mode="test") |
@@ -1,6 +1,6 @@ | |||
import numpy as np | |||
import torch | |||
import os | |||
class AvgMeter: | |||
def __init__(self, name="Metric"): | |||
@@ -91,3 +91,12 @@ class EarlyStopping: | |||
# self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...') | |||
# torch.save(model.state_dict(), self.path) | |||
# self.val_loss_min = val_loss | |||
def make_directory(path):
    """Create the directory ``path``, including any missing parents.

    The outcome is reported on stdout. Errors are not raised, so callers can
    treat directory creation as best effort.

    Args:
        path: Directory to create. An already existing directory counts as
            success.
    """
    try:
        # makedirs creates intermediate directories; exist_ok keeps repeated
        # runs from reporting a failure for a directory that is already there.
        os.makedirs(path, exist_ok=True)
    except OSError:
        print("Creation of the directory failed")
    else:
        print("Successfully created the directory")