
similarity saving added

Branch: master
Author: Faeze, 1 year ago
Commit: 2afb3d3e28
4 changed files with 34 additions and 35 deletions:

  1. data/twitter/config.py (+8 -8)
  2. data/weibo/config.py (+6 -6)
  3. model.py (+12 -14)
  4. test_main.py (+8 -7)

data/twitter/config.py (+8 -8)

@@ -38,13 +38,13 @@ class TwitterConfig(Config):
     attention_weight_decay = 0.05
     classification_weight_decay = 0.05
 
-    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-    image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
+    image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224'
     image_embedding = 768
-    text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
+    text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased"
     # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
-    text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
+    text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased"
     # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
     text_embedding = 768
     max_length = 32
@@ -65,9 +65,9 @@ class TwitterConfig(Config):
         self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)
 
         self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
-        # self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
-        self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)
+        self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
+        # self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)
 
-        # self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
+        self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
         # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
-        # self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])
+        self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])
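The block above is the Optuna hook that decides which hyperparameters get tuned per trial. A minimal, self-contained sketch of how such suggest_* calls are driven by a study; the objective body here is hypothetical, only the parameter names and ranges come from the config:

import optuna


def objective(trial):
    # Same names/ranges as the config above; the return value is a stand-in
    # for "build a model from the sampled config, train it, return a metric".
    head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
    attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
    projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
    dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5])
    return head_weight_decay + attention_weight_decay * dropout / projection_size


study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=10)
print(study.best_params)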

data/weibo/config.py (+6 -6)

@@ -33,13 +33,13 @@ class WeiboConfig(Config):
     projection_size = 64
     dropout = 0.5
 
-    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-    image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
+    image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224'
     image_embedding = 768
-    text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-chinese"
+    text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-chinese"
     # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
-    text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-chinese"
+    text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-chinese"
     # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
     text_embedding = 768
     max_length = 200
@@ -59,9 +59,9 @@ class WeiboConfig(Config):
         self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)
 
         self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
-        self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)
+        # self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)
 
         self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
-        self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
+        # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
         self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])
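Both configs point text_tokenizer, text_encoder_model, and image_model_name at local snapshots of HuggingFace models. The loading code lives elsewhere in the repo (text.py / image.py) and is not part of this diff, so the following is an assumption about how these paths are consumed:

from transformers import BertModel, BertTokenizer, ViTModel

from data.weibo.config import WeiboConfig

config = WeiboConfig()
# A local directory path behaves like a hub model id for from_pretrained.
tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
text_encoder = BertModel.from_pretrained(config.text_encoder_model)   # bert-base-chinese, hidden size 768
image_encoder = ViTModel.from_pretrained(config.image_model_name)     # vit-base-patch16-224, hidden size 768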


model.py (+12 -14)

@@ -4,6 +4,7 @@ from torch import nn
 
 from image import ImageTransformerEncoder, ImageResnetEncoder
 from text import TextEncoder
+import numpy as np
 
 
 class ProjectionHead(nn.Module):
@@ -18,7 +19,7 @@ class ProjectionHead(nn.Module):
         self.gelu = nn.GELU()
         self.fc = nn.Linear(config.projection_size, config.projection_size)
         self.dropout = nn.Dropout(config.dropout)
-        self.layer_norm = nn.LayerNorm(config.projection_size)
+        self.layer_norm = nn.BatchNorm1d(config.projection_size)
 
     def forward(self, x):
         projected = self.projection(x)
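For context, a self-contained version of ProjectionHead after this change. The forward body past projected = self.projection(x) is cut off by the hunk, so the fc/dropout/residual ordering below follows the common CLIP projection-head pattern and is an assumption; the constructor is also flattened to plain arguments instead of a config object:

import torch
from torch import nn


class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_size, dropout):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_size)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_size, projection_size)
        self.dropout = nn.Dropout(dropout)
        # BatchNorm1d normalizes each feature across the batch (and needs
        # batch size > 1 in training mode); the LayerNorm it replaces
        # normalized across features within each sample.
        self.layer_norm = nn.BatchNorm1d(projection_size)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection around the second linear layer
        return self.layer_norm(x)


head = ProjectionHead(embedding_dim=768, projection_size=64, dropout=0.5)
out = head(torch.randn(8, 768))  # -> shape (8, 64)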
@@ -81,39 +82,36 @@ class FakeNewsModel(nn.Module):
         self.text_embeddings = self.text_projection(text_features)
 
         self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
-        self.logits = (self.text_embeddings @ self.image_embeddings.T)
 
         score = self.classifier(self.multimodal_embeddings)
         probs, output = torch.max(score.data, dim=1)
 
-        return output, score
+        self.similarity = self.text_embeddings @ self.image_embeddings.T
+
+        return output, score
 
 
 def calculate_loss(model, score, label):
-    similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings)
-    # fake = (label == 1).nonzero()
-    # real = (label == 0).nonzero()
-    # s_loss = 0 * similarity[fake].mean() + similarity[real].mean()
-    s_loss = similarity.mean()
+    s_loss = calculate_similarity_loss(model.image_embeddings, model.text_embeddings)
     c_loss = model.classifier_loss_function(score, label)
     loss = model.config.loss_weight * c_loss + s_loss
     return loss, c_loss, s_loss
 
 
-def calculate_similarity_loss(config, image_embeddings, text_embeddings):
-    # Calculating the Loss
+def calculate_similarity_loss(image_embeddings, text_embeddings):
     logits = (text_embeddings @ image_embeddings.T)
-    images_similarity = image_embeddings @ image_embeddings.T
-    texts_similarity = text_embeddings @ text_embeddings.T
+    images_similarity = (image_embeddings @ image_embeddings.T)
+    texts_similarity = (text_embeddings @ text_embeddings.T)
     targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
-    texts_loss = cross_entropy(logits, targets, reduction='none')
-    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
+    texts_loss = cross_entropy(logits, targets, reduction='mean')
+    images_loss = cross_entropy(logits.T, targets.T, reduction='mean')
     loss = (images_loss + texts_loss) / 2.0
     return loss
 
 
 def cross_entropy(preds, targets, reduction='none'):
-    # entropy = torch.nn.CrossEntropyLoss(reduction=reduction)
-    # return entropy(preds, targets)
     log_softmax = nn.LogSoftmax(dim=-1)
     loss = (-targets * log_softmax(preds)).sum(1)
     if reduction == "none":
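The hunk window also cuts cross_entropy off at the reduction check. A runnable sketch of the full pair of helpers, where the 'mean' branch is an assumption inferred from the reduction='mean' calls introduced above:

import torch
import torch.nn.functional as F
from torch import nn


def cross_entropy(preds, targets, reduction='none'):
    # Soft-target cross-entropy: each row of `targets` is a distribution.
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()


def calculate_similarity_loss(image_embeddings, text_embeddings):
    # Symmetric CLIP-style loss: soft targets come from averaging the
    # intra-modal similarity matrices of the image and text embeddings.
    logits = text_embeddings @ image_embeddings.T
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
    texts_loss = cross_entropy(logits, targets, reduction='mean')
    images_loss = cross_entropy(logits.T, targets.T, reduction='mean')
    return (images_loss + texts_loss) / 2.0


# Quick check with random embeddings (batch of 4, projection size 64):
img, txt = torch.randn(4, 64), torch.randn(4, 64)
print(calculate_similarity_loss(img, txt))  # scalar tensor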

test_main.py (+8 -7)

@@ -13,13 +13,11 @@ from model import FakeNewsModel, calculate_loss
 
 
 def test(config, test_loader, trial_number=None):
-    if trial_number:
-        try:
-            checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
-        except:
-            checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
-    else:
-        checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))
+    try:
+        checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
+    except:
+        checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
 
     try:
         parameters = checkpoint['parameters']
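The simplified loader in isolation, as a hypothetical standalone helper. With trial_number=None the first load looks for 'checkpoint_None.pt' and falls through to 'checkpoint.pt'; the map_location argument and the narrower except clause are defensive choices of this sketch, not part of the commit:

import torch


def load_checkpoint(output_path, trial_number=None, device='cpu'):
    try:
        # trial_number=None yields 'checkpoint_None.pt', which normally fails.
        return torch.load(f'{output_path}/checkpoint_{trial_number}.pt',
                          map_location=torch.device(device))
    except FileNotFoundError:
        return torch.load(f'{output_path}/checkpoint.pt',
                          map_location=torch.device(device))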
@@ -49,6 +47,7 @@ def test(config, test_loader, trial_number=None):
     scores = []
     ids = []
     losses = []
+    similarities = []
     tqdm_object = tqdm(test_loader, total=len(test_loader))
     for i, batch in enumerate(tqdm_object):
         batch = batch_constructor(config, batch)
@@ -64,6 +63,7 @@ def test(config, test_loader, trial_number=None):
         text_features.append(model.text_embeddings.detach())
         multimodal_features.append(model.multimodal_embeddings.detach())
         concat_features.append(model.classifier.embeddings.detach())
+        similarities.append(model.similarity.detach())
         losses.append((loss.detach(), c_loss.detach(), s_loss.detach()))
 
     s = ''
@@ -79,6 +79,7 @@ def test(config, test_loader, trial_number=None):
     save_embedding(text_features, fname=str(config.output_path) + '/new_text_features.tsv')
     save_embedding(multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv')
     save_embedding(concat_features, fname=str(config.output_path) + '/new_concat_features.tsv')
+    save_embedding(similarities, fname=str(config.output_path) + '/similarities.tsv')
 
     config_parameters = str(config)
     with open(config.output_path + '/new_parameters.txt', 'w') as f:
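save_embedding is defined elsewhere in the repo and not shown in this diff; a helper with this signature plausibly looks like the hypothetical sketch below (one tab-separated row per sample):

import torch


def save_embedding(batches, fname):
    # batches: the list of detached tensors collected per batch above.
    # Caveat: model.similarity is batch x batch, so a final partial batch has
    # a different width; per-batch saving or padding would be needed there.
    matrix = torch.cat(batches, dim=0).cpu().numpy()
    with open(fname, 'w') as f:
        for row in matrix:
            f.write('\t'.join(str(v) for v in row) + '\n')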
