import torch
import torch.nn.functional as F
from torch import nn

from image import ImageTransformerEncoder, ImageResnetEncoder
from text import TextEncoder


class ProjectionHead(nn.Module):
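    """Residual MLP head that projects encoder features into the shared embedding space."""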
    def __init__(self, config, embedding_dim):
        super().__init__()
        self.config = config
        self.projection = nn.Linear(embedding_dim, config.projection_size)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(config.projection_size, config.projection_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.projection_size)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection around the MLP block
        x = self.layer_norm(x)
        return x


class Classifier(nn.Module):
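    """MLP classifier over the concatenated image/text embedding."""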
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(2 * config.projection_size)
        self.linear_layer = nn.Linear(2 * config.projection_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.drop_out = nn.Dropout(config.dropout)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.classifier_layer = nn.Linear(config.hidden_size, config.class_num)

    def forward(self, x):
        x = self.layer_norm_1(x)
        x = self.linear_layer(x)
        x = self.gelu(x)
        x = self.drop_out(x)
        x = self.layer_norm_2(x)
        self.embeddings = x  # cache the penultimate features for later use
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax itself, so
        # the nn.Softmax originally applied here double-normalized the scores.
        return self.classifier_layer(x)


class FakeNewsModel(nn.Module):
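    """Encodes image and text, projects both into a shared space, and classifies
    their concatenation. The per-modality embeddings are kept as attributes so
    the contrastive similarity loss can reuse them."""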
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Choose the image backbone from the configured model name.
        if 'resnet' in self.config.image_model_name:
            self.image_encoder = ImageResnetEncoder(config)
        else:
            self.image_encoder = ImageTransformerEncoder(config)
        self.text_encoder = TextEncoder(config)
        self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
        self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
        self.classifier = Classifier(config)
        class_weights = torch.FloatTensor(config.class_weights)
        self.classifier_loss_function = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

        # Populated on every forward pass so the loss functions can reuse them.
        self.text_embeddings = None
        self.image_embeddings = None
        self.multimodal_embeddings = None
        self.similarity = None

    def forward(self, batch):
        image_features = self.image_encoder(ids=batch['id'], image=batch['image'])
        text_features = self.text_encoder(ids=batch['id'],
                                          input_ids=batch['input_ids'],
                                          attention_mask=batch['attention_mask'])
        self.image_embeddings = self.image_projection(image_features)
        self.text_embeddings = self.text_projection(text_features)

        self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)

        # `score` holds raw logits; argmax over them yields the predicted class.
        score = self.classifier(self.multimodal_embeddings)
        output = torch.argmax(score, dim=1)

        # Cross-modal similarity matrix, kept for inspection and debugging.
        self.similarity = self.text_embeddings @ self.image_embeddings.T

        return output, score
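
# NOTE (editor's sketch): forward() expects `batch` to be a dict with keys
# 'id', 'image', 'input_ids', and 'attention_mask'; ground-truth labels are
# passed separately to calculate_loss() below.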


def calculate_loss(model, score, label):
    # Joint objective: weighted classification loss plus the contrastive
    # similarity loss between the projected image and text embeddings.
    s_loss = calculate_similarity_loss(model.image_embeddings, model.text_embeddings)
    c_loss = model.classifier_loss_function(score, label)
    loss = model.config.loss_weight * c_loss + s_loss
    return loss, c_loss, s_loss


def calculate_similarity_loss(image_embeddings, text_embeddings):
    # CLIP-style symmetric contrastive loss. Soft targets are the softmax of
    # the averaged intra-modal similarity matrices, so each modality is pulled
    # toward pairs its counterpart already considers similar.
    logits = text_embeddings @ image_embeddings.T
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
    texts_loss = cross_entropy(logits, targets, reduction='mean')
    images_loss = cross_entropy(logits.T, targets.T, reduction='mean')
    loss = (images_loss + texts_loss) / 2.0
    return loss


def cross_entropy(preds, targets, reduction='none'):
    # Soft-target cross-entropy: `targets` are probability distributions
    # rather than class indices, so nn.CrossEntropyLoss is not used here.
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    raise ValueError(f'unsupported reduction: {reduction!r}')
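

# Minimal sanity-check sketch (an illustrative addition, not part of the
# original pipeline): it exercises only the self-contained loss utilities
# above on random embeddings, since the encoders need the full project config
# and the `image`/`text` modules to be importable.
if __name__ == '__main__':
    torch.manual_seed(0)
    image_embeddings = torch.randn(4, 256)  # 4 samples, 256-dim projections
    text_embeddings = torch.randn(4, 256)
    loss = calculate_similarity_loss(image_embeddings, text_embeddings)
    print(f'similarity loss on random embeddings: {loss.item():.4f}')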