Browse Source

lime added

master
FaezeGhorbanpour 2 years ago
parent
commit
47c42ec9e8
14 changed files with 219 additions and 195 deletions
  1. 2
    2
      data/config.py
  2. 42
    11
      data/data_loader.py
  3. 12
    8
      data/twitter/config.py
  4. 4
    37
      data/twitter/data_loader.py
  5. 16
    14
      data/weibo/config.py
  6. 4
    36
      data/weibo/data_loader.py
  7. 4
    51
      data_loaders.py
  8. 110
    0
      lime_main.py
  9. 7
    6
      main.py
  10. 4
    26
      model.py
  11. 2
    1
      optuna_main.py
  12. 1
    1
      test_main.py
  13. 1
    1
      torch_main.py
  14. 10
    1
      utils.py

+ 2
- 2
data/config.py View File

size = 224 size = 224


num_projection_layers = 1 num_projection_layers = 1
projection_size = 256
projection_size = 64
dropout = 0.3 dropout = 0.3
hidden_size = 256
hidden_size = 128
num_region = 64 # 16 #TODO num_region = 64 # 16 #TODO
region_size = projection_size // num_region region_size = projection_size // num_region



+ 42
- 11
data/data_loader.py View File

import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
import albumentations as A
from transformers import ViTFeatureExtractor
from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer


def get_transforms(config):
return A.Compose(
[
A.Resize(config.size, config.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
]
)


def get_tokenizer(config):
if 'bigbird' in config.text_encoder_model:
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
elif 'xlnet' in config.text_encoder_model:
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
else:
tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
return tokenizer



class DatasetLoader(Dataset): class DatasetLoader(Dataset):
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
def __init__(self, config, dataframe, mode):
self.config = config self.config = config
self.image_filenames = image_filenames
self.text = list(texts)
self.encoded_text = tokenizer(
list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt'
)
self.labels = labels
self.transforms = transforms
self.mode = mode self.mode = mode
if mode == 'lime':
self.image_filenames = [dataframe["image"],]
self.text = [dataframe["text"],]
self.labels = [dataframe["label"],]
else:
self.image_filenames = dataframe["image"].values
self.text = list(dataframe["text"].values)
self.labels = dataframe["label"].values

tokenizer = get_tokenizer(config)
self.encoded_text = tokenizer(self.text, padding=True, truncation=True, max_length=config.max_length, return_tensors='pt')
if 'resnet' in config.image_model_name:
self.transforms = get_transforms(config)
else:
self.transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)


def set_text(self, idx): def set_text(self, idx):
item = { item = {
item['id'] = idx item['id'] = idx
return item return item


def set_image(self, image, item):
def set_image(self, image):
if 'resnet' in self.config.image_model_name: if 'resnet' in self.config.image_model_name:
image = self.transforms(image=image)['image'] image = self.transforms(image=image)['image']
item['image'] = torch.as_tensor(image).reshape((3, 224, 224))
return {'image': torch.as_tensor(image).reshape((3, 224, 224))}
else: else:
image = self.transforms(images=image, return_tensors='pt') image = self.transforms(images=image, return_tensors='pt')
image = image.convert_to_tensors(tensor_type='pt')['pixel_values'] image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
item['image'] = image.reshape((3, 224, 224))
return {'image': image.reshape((3, 224, 224))}

+ 12
- 8
data/twitter/config.py View File

name = 'twitter' name = 'twitter'
DatasetLoader = TwitterDatasetLoader DatasetLoader = TwitterDatasetLoader


data_path = '/twitter/'
output_path = ''
data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/twitter/'
# data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/twitter/'
output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/'
# output_path = ''


train_image_path = data_path + 'images_train/' train_image_path = data_path + 'images_train/'
validation_image_path = data_path + 'images_validation/'
validation_image_path = data_path + 'images_test/'
test_image_path = data_path + 'images_test/' test_image_path = data_path + 'images_test/'


train_text_path = data_path + 'twitter_train_translated.csv' train_text_path = data_path + 'twitter_train_translated.csv'
validation_text_path = data_path + 'twitter_validation_translated.csv'
validation_text_path = data_path + 'twitter_test_translated.csv'
test_text_path = data_path + 'twitter_test_translated.csv' test_text_path = data_path + 'twitter_test_translated.csv'


batch_size = 128 batch_size = 128
attention_weight_decay = 0.001 attention_weight_decay = 0.001
classification_weight_decay = 0.001 classification_weight_decay = 0.001


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")


image_model_name = 'vit-base-patch16-224'
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
image_embedding = 768 image_embedding = 768
text_encoder_model = "bert-base-uncased"
text_tokenizer = "bert-base-uncased"
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_embedding = 768 text_embedding = 768
max_length = 32 max_length = 32



+ 4
- 37
data/twitter/data_loader.py View File

class TwitterDatasetLoader(DatasetLoader): class TwitterDatasetLoader(DatasetLoader):
def __getitem__(self, idx): def __getitem__(self, idx):
item = self.set_text(idx) item = self.set_text(idx)
item.update(self.set_img(idx))
return item


def set_img(self, idx):
try: try:
if self.mode == 'train': if self.mode == 'train':
image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}") image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}")
except: except:
image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB') image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB')
image = asarray(image) image = asarray(image)
self.set_image(image, item)
return item

return self.set_image(image)


def __len__(self): def __len__(self):
return len(self.text) return len(self.text)



class TwitterEmbeddingDatasetLoader(DatasetLoader):
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
self.config = config
self.image_filenames = image_filenames
self.text = list(texts)
self.encoded_text = tokenizer
self.labels = labels
self.transforms = transforms
self.mode = mode

if mode == 'train':
with open(config.train_image_embedding_path, 'rb') as handle:
self.image_embedding = pickle.load(handle)
with open(config.train_text_embedding_path, 'rb') as handle:
self.text_embedding = pickle.load(handle)
else:
with open(config.validation_image_embedding_path, 'rb') as handle:
self.image_embedding = pickle.load(handle)
with open(config.validation_text_embedding_path, 'rb') as handle:
self.text_embedding = pickle.load(handle)


def __getitem__(self, idx):
item = dict()
item['id'] = idx
item['image_embedding'] = self.image_embedding[idx]
item['text_embedding'] = self.text_embedding[idx]
item['label'] = self.labels[idx]
return item

def __len__(self):
return len(self.text)

+ 16
- 14
data/weibo/config.py View File

name = 'weibo' name = 'weibo'
DatasetLoader = WeiboDatasetLoader DatasetLoader = WeiboDatasetLoader


data_path = 'weibo/'
output_path = ''
data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/weibo/'
# data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/weibo/'
output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/'
# output_path = ''


rumor_image_path = data_path + 'rumor_images/' rumor_image_path = data_path + 'rumor_images/'
nonrumor_image_path = data_path + 'nonrumor_images/' nonrumor_image_path = data_path + 'nonrumor_images/'


train_text_path = data_path + 'weibo_train.csv' train_text_path = data_path + 'weibo_train.csv'
validation_text_path = data_path + 'weibo_validation.csv'
validation_text_path = data_path + 'weibo_train.csv'
test_text_path = data_path + 'weibo_test.csv' test_text_path = data_path + 'weibo_test.csv'


batch_size = 128
batch_size = 64
epochs = 100 epochs = 100
num_workers = 2
num_workers = 4
head_lr = 1e-03 head_lr = 1e-03
image_encoder_lr = 1e-02 image_encoder_lr = 1e-02
text_encoder_lr = 1e-05 text_encoder_lr = 1e-05
weight_decay = 0.001 weight_decay = 0.001
classification_lr = 1e-02 classification_lr = 1e-02


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")


image_model_name = 'vit-base-patch16-224' # 'resnet101'
image_embedding = 768 # 2048
num_img_region = 64 # TODO
text_encoder_model = "bert-base-chinese"
text_tokenizer = "bert-base-chinese"
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
image_embedding = 768
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_embedding = 768 text_embedding = 768
max_length = 200 max_length = 200


self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1) self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)


self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1) self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
# self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1) self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)


self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64]) self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
# self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
# self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])
self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])



+ 4
- 36
data/weibo/data_loader.py View File

class WeiboDatasetLoader(DatasetLoader): class WeiboDatasetLoader(DatasetLoader):
def __getitem__(self, idx): def __getitem__(self, idx):
item = self.set_text(idx) item = self.set_text(idx)
item.update(self.set_img(idx))
return item


def set_img(self, idx):
if self.labels[idx] == 1: if self.labels[idx] == 1:
try: try:
image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}") image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}")
except: except:
image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB') image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB')
image = asarray(image) image = asarray(image)
self.set_image(image, item)
return item
return self.set_image(image)


def __len__(self): def __len__(self):
return len(self.text) return len(self.text)



class WeiboEmbeddingDatasetLoader(DatasetLoader):
def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
self.config = config
self.image_filenames = image_filenames
self.text = list(texts)
self.encoded_text = tokenizer
self.labels = labels
self.transforms = transforms
self.mode = mode

if mode == 'train':
with open(config.train_image_embedding_path, 'rb') as handle:
self.image_embedding = pickle.load(handle)
with open(config.train_text_embedding_path, 'rb') as handle:
self.text_embedding = pickle.load(handle)
else:
with open(config.validation_image_embedding_path, 'rb') as handle:
self.image_embedding = pickle.load(handle)
with open(config.validation_text_embedding_path, 'rb') as handle:
self.text_embedding = pickle.load(handle)


def __getitem__(self, idx):
item = dict()
item['id'] = idx
item['image_embedding'] = self.image_embedding[idx]
item['text_embedding'] = self.text_embedding[idx]
item['label'] = self.labels[idx]
return item

def __len__(self):
return len(self.text)

+ 4
- 51
data_loaders.py View File

import albumentations as A
import numpy as np import numpy as np
import pandas as pd import pandas as pd


from sklearn.utils import compute_class_weight from sklearn.utils import compute_class_weight
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer


def get_transforms(config, mode="train"):
if mode == "train":
return A.Compose(
[
A.Resize(config.size, config.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
]
)
else:
return A.Compose(
[
A.Resize(config.size, config.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
]
)




def make_dfs(config): def make_dfs(config):
validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True) validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True)
validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x)) validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x))


unlabeled_dataframe = None
if config.has_unlabeled_data:
unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path)
unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True)

return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe


def get_tokenizer(config):
if 'bigbird' in config.text_encoder_model:
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
elif 'xlnet' in config.text_encoder_model:
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
else:
tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
return tokenizer
return train_dataframe, test_dataframe, validation_dataframe




def build_loaders(config, dataframe, mode): def build_loaders(config, dataframe, mode):
if 'resnet' in config.image_model_name:
transforms = get_transforms(config, mode=mode)
else:
transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)

dataset = config.DatasetLoader(
config,
dataframe["image"].values,
dataframe["text"].values,
dataframe["label"].values,
tokenizer=get_tokenizer(config),
transforms=transforms,
mode=mode,
)
dataset = config.DatasetLoader(config, dataframe=dataframe, mode=mode)
if mode != 'train': if mode != 'train':
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
batch_size=config.batch_size,
num_workers=config.num_workers,
batch_size=config.batch_size // 2 if mode == 'lime' else 1,
num_workers=config.num_workers // 2,
pin_memory=False, pin_memory=False,
shuffle=False, shuffle=False,
) )

+ 110
- 0
lime_main.py View File

import random

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from learner import batch_constructor
from model import FakeNewsModel

from lime.lime_image import LimeImageExplainer
from skimage.segmentation import mark_boundaries
from lime.lime_text import LimeTextExplainer

from utils import make_directory


def lime_(config, test_loader, model):
for i, batch in enumerate(test_loader):
batch = batch_constructor(config, batch)
with torch.no_grad():
output, score = model(batch)

score = score.detach().cpu().numpy()

logit = model.logits.detach().cpu().numpy()

return score, logit


def lime_main(config, trial_number=None):
_, test_df, _ = make_dfs(config, )
test_df = test_df[:1]

if trial_number:
try:
checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
except:
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
else:
checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))

try:
parameters = checkpoint['parameters']
config.assign_hyperparameters(parameters)
except:
pass

model = FakeNewsModel(config).to(config.device)
try:
model.load_state_dict(checkpoint['model_state_dict'])
except:
model.load_state_dict(checkpoint)

model.eval()

torch.manual_seed(27)
random.seed(27)
np.random.seed(27)

text_explainer = LimeTextExplainer(class_names=config.classes)
image_explainer = LimeImageExplainer()

make_directory(config.output_path + '/lime/')
make_directory(config.output_path + '/lime/score/')
make_directory(config.output_path + '/lime/logit/')
make_directory(config.output_path + '/lime/text/')
make_directory(config.output_path + '/lime/image/')

def text_predict_proba(text):
scores = []
print(len(text))
for i in text:
row['text'] = i
test_loader = build_loaders(config, row, mode="lime")
score, _ = lime_(config, test_loader, model)
scores.append(score)
return np.array(scores)

def image_predict_proba(image):
scores = []
print(len(image))
for i in image:
test_loader = build_loaders(config, row, mode="lime")
test_loader['image'] = i.reshape((3, 224, 224))
score, _ = lime_(config, test_loader, model)
scores.append(score)
return np.array(scores)

for i, row in test_df.iterrows():
test_loader = build_loaders(config, row, mode="lime")
score, logit = lime_(config, test_loader, model)
np.savetxt(config.output_path + '/lime/score/' + str(i) + '.csv', score, delimiter=",")
np.savetxt(config.output_path + '/lime/logit/' + str(i) + '.csv', logit, delimiter=",")

text_exp = text_explainer.explain_instance(row['text'], text_predict_proba, num_features=5)
text_exp.save_to_file(config.output_path + '/lime/text/' + str(i) + '.html')
print('text', i, 'finished')

data_items = config.DatasetLoader(config, dataframe=row, mode='lime').__getitem__(0)
img_exp = image_explainer.explain_instance(data_items['image'].reshape((224, 224, 3)), image_predict_proba,
top_labels=2, hide_color=0, num_samples=1)
temp, mask = img_exp.get_image_and_mask(img_exp.top_labels[0], positive_only=False, num_features=5,
hide_rest=False)
img_boundry = mark_boundaries(temp / 255.0, mask)
plt.imshow(img_boundry)
plt.savefig(config.output_path + '/lime/image/' + str(i) + '.png')
print('image', i, 'finished')

+ 7
- 6
main.py View File

import argparse import argparse


from lime_main import lime_main
from optuna_main import optuna_main from optuna_main import optuna_main
from test_main import test_main from test_main import test_main
from torch_main import torch_main from torch_main import torch_main
from data.weibo.config import WeiboConfig from data.weibo.config import WeiboConfig
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter


from utils import make_directory

if __name__ == '__main__': if __name__ == '__main__':
import os import os


parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, required=True) parser.add_argument('--data', type=str, required=True)
parser.add_argument('--use_optuna', type=int, required=False) parser.add_argument('--use_optuna', type=int, required=False)
parser.add_argument('--use_lime', type=int, required=False)
parser.add_argument('--just_test', type=int, required=False) parser.add_argument('--just_test', type=int, required=False)
parser.add_argument('--batch', type=int, required=False) parser.add_argument('--batch', type=int, required=False)
parser.add_argument('--epoch', type=int, required=False) parser.add_argument('--epoch', type=int, required=False)
config.output_path = str(args.extra) config.output_path = str(args.extra)
use_optuna = False use_optuna = False


try:
os.mkdir(config.output_path)
except OSError:
print("Creation of the directory failed")
else:
print("Successfully created the directory")
make_directory(config.output_path)


config.writer = SummaryWriter(config.output_path) config.writer = SummaryWriter(config.output_path)


optuna_main(config, args.use_optuna) optuna_main(config, args.use_optuna)
elif args.just_test: elif args.just_test:
test_main(config, args.just_test) test_main(config, args.just_test)
elif args.use_lime:
lime_main(config, args.use_lime)
else: else:
torch_main(config) torch_main(config)



+ 4
- 26
model.py View File

self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding) self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding) self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
self.classifier = Classifier(config) self.classifier = Classifier(config)
self.temperature = config.temperature
class_weights = torch.FloatTensor(config.class_weights) class_weights = torch.FloatTensor(config.class_weights)
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean') self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')


self.text_embeddings = self.text_projection(text_features) self.text_embeddings = self.text_projection(text_features)


self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1) self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
score = self.classifier(self.multimodal_embeddings)
probs, output = torch.max(score.data, dim=1)

return output, score


class FakeNewsEmbeddingModel(nn.Module):
def __init__(
self, config
):
super().__init__()
self.config = config
self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
self.classifier = Classifier(config)
self.temperature = config.temperature
class_weights = torch.FloatTensor(config.class_weights)
self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

def forward(self, batch):
self.image_embeddings = self.image_projection(batch['image_embedding'])
self.text_embeddings = self.text_projection(batch['text_embedding'])
self.logits = (self.text_embeddings @ self.image_embeddings.T)


self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
score = self.classifier(self.multimodal_embeddings) score = self.classifier(self.multimodal_embeddings)
probs, output = torch.max(score.data, dim=1) probs, output = torch.max(score.data, dim=1)


return output, score return output, score





def calculate_loss(model, score, label): def calculate_loss(model, score, label):
similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings) similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings)
# fake = (label == 1).nonzero() # fake = (label == 1).nonzero()


def calculate_similarity_loss(config, image_embeddings, text_embeddings): def calculate_similarity_loss(config, image_embeddings, text_embeddings):
# Calculating the Loss # Calculating the Loss
logits = (text_embeddings @ image_embeddings.T) / config.temperature
logits = (text_embeddings @ image_embeddings.T)
images_similarity = image_embeddings @ image_embeddings.T images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax((images_similarity + texts_similarity) / 2 * config.temperature, dim=-1)
targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
texts_loss = cross_entropy(logits, targets, reduction='none') texts_loss = cross_entropy(logits, targets, reduction='none')
images_loss = cross_entropy(logits.T, targets.T, reduction='none') images_loss = cross_entropy(logits.T, targets.T, reduction='none')
loss = (images_loss + texts_loss) / 2.0 loss = (images_loss + texts_loss) / 2.0

+ 2
- 1
optuna_main.py View File





def optuna_main(config, n_trials=100): def optuna_main(config, n_trials=100):
train_df, test_df, validation_df, _ = make_dfs(config)
train_df, test_df, validation_df = make_dfs(config)
train_loader = build_loaders(config, train_df, mode="train") train_loader = build_loaders(config, train_df, mode="train")
validation_loader = build_loaders(config, validation_df, mode="validation") validation_loader = build_loaders(config, validation_df, mode="validation")
test_loader = build_loaders(config, test_df, mode="test") test_loader = build_loaders(config, test_df, mode="test")
trial = study.best_trial trial = study.best_trial
print(' Number: ', trial.number) print(' Number: ', trial.number)
print(" Value: ", trial.value) print(" Value: ", trial.value)
s += 'number: ' + str(trial.number) + '\n'
s += 'value: ' + str(trial.value) + '\n' s += 'value: ' + str(trial.value) + '\n'


print(" Params: ") print(" Params: ")

+ 1
- 1
test_main.py View File





def test_main(config, trial_number=None): def test_main(config, trial_number=None):
train_df, test_df, validation_df, _ = make_dfs(config, )
train_df, test_df, validation_df = make_dfs(config, )
test_loader = build_loaders(config, test_df, mode="test") test_loader = build_loaders(config, test_df, mode="test")
test(config, test_loader, trial_number) test(config, test_loader, trial_number)

+ 1
- 1
torch_main.py View File





def torch_main(config): def torch_main(config):
train_df, test_df, validation_df, _ = make_dfs(config, )
train_df, test_df, validation_df = make_dfs(config, )
train_loader = build_loaders(config, train_df, mode="train") train_loader = build_loaders(config, train_df, mode="train")
validation_loader = build_loaders(config, validation_df, mode="validation") validation_loader = build_loaders(config, validation_df, mode="validation")
test_loader = build_loaders(config, test_df, mode="test") test_loader = build_loaders(config, test_df, mode="test")

+ 10
- 1
utils.py View File

import numpy as np import numpy as np
import torch import torch
import os


class AvgMeter: class AvgMeter:
def __init__(self, name="Metric"): def __init__(self, name="Metric"):
# self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...') # self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Model saved ...')
# torch.save(model.state_dict(), self.path) # torch.save(model.state_dict(), self.path)
# self.val_loss_min = val_loss # self.val_loss_min = val_loss


def make_directory(path):
try:
os.mkdir(path)
except OSError:
print("Creation of the directory failed")
else:
print("Successfully created the directory")

Loading…
Cancel
Save