import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from image import ImageTransformerEncoder, ImageResnetEncoder
from text import TextEncoder


class ProjectionHead(nn.Module):
    """Project encoder features to ``config.projection_size`` with a residual block.

    Pipeline: linear projection -> GELU -> linear -> dropout -> add the
    projection back (residual) -> layer norm.
    """

    def __init__(
            self,
            config,
            embedding_dim,
    ):
        """Build the layers.

        Args:
            config: Object providing ``projection_size`` and ``dropout``.
            embedding_dim: Size of the last dimension of the input features.
        """
        super().__init__()
        self.config = config
        self.projection = nn.Linear(embedding_dim, config.projection_size)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(config.projection_size, config.projection_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.projection_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalised projection of ``x``.

        Args:
            x: Features whose last dimension is ``embedding_dim``.

        Returns:
            Tensor whose last dimension is ``config.projection_size``.
        """
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        # Residual connection around the GELU/fc/dropout branch.
        x = x + projected
        x = self.layer_norm(x)
        return x


class Classifier(nn.Module):
    """Classification head over the concatenated image and text projections.

    Pipeline: layer norm -> linear -> GELU -> dropout -> layer norm ->
    linear -> softmax over dim 1.
    """

    def __init__(self, config):
        """Build the layers.

        Args:
            config: Object providing ``projection_size``, ``hidden_size``,
                ``dropout`` and ``class_num``.
        """
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(2 * config.projection_size)
        self.linear_layer = nn.Linear(2 * config.projection_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.drop_out = nn.Dropout(config.dropout)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.classifier_layer = nn.Linear(config.hidden_size, config.class_num)
        self.softmax = nn.Softmax(dim=1)
        # Hidden representation cached by the most recent forward pass.
        self.embeddings = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for ``x``.

        Side effect: the output of the second layer norm is stored in
        ``self.embeddings``.

        Args:
            x: Input whose last dimension is ``2 * config.projection_size``;
                presumably (batch, features) given the softmax over dim 1 --
                confirm against callers.

        Returns:
            Softmax-normalised scores (probabilities, not raw logits).
        """
        x = self.layer_norm_1(x)
        x = self.linear_layer(x)
        x = self.gelu(x)
        x = self.drop_out(x)
        self.embeddings = x = self.layer_norm_2(x)
        x = self.classifier_layer(x)
        x = self.softmax(x)
        return x


class FakeNewsModel(nn.Module):
    """Multimodal classifier combining an image encoder and a text encoder.

    Each encoder's features go through its own ``ProjectionHead``; the two
    projections are concatenated and passed to ``Classifier``. Intermediate
    embeddings from the most recent forward pass are kept on the instance
    (``image_embeddings``, ``text_embeddings``, ``multimodal_embeddings``,
    ``similarity``).
    """

    def __init__(
            self, config
    ):
        """Build encoders, projection heads, classifier and loss function.

        Args:
            config: Object providing ``image_model_name``, ``image_embedding``,
                ``text_embedding``, ``class_weights`` and every field required
                by ``ProjectionHead``, ``Classifier`` and the encoders.
        """
        super().__init__()
        self.config = config
        # A ResNet encoder is used when the model name mentions 'resnet';
        # any other name selects the transformer image encoder.
        if 'resnet' in self.config.image_model_name:
            self.image_encoder = ImageResnetEncoder(config)
        else:
            self.image_encoder = ImageTransformerEncoder(config)

        self.text_encoder = TextEncoder(config)
        self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
        self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
        self.classifier = Classifier(config)

        class_weights = torch.FloatTensor(config.class_weights)
        self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

        # Caches filled by forward(); None until the first forward pass.
        self.text_embeddings = None
        self.image_embeddings = None
        self.multimodal_embeddings = None
        self.similarity = None

    def forward(self, batch):
        """Run the model on one batch.

        Args:
            batch: Mapping with keys ``'id'``, ``'image'``, ``'input_ids'``
                and ``'attention_mask'``, forwarded to the encoders.

        Returns:
            Tuple ``(output, score)`` where ``score`` is the classifier output
            (softmax probabilities) and ``output`` is the index of the
            highest-scoring class along dim 1.
        """
        image_features = self.image_encoder(ids=batch['id'], image=batch["image"])
        text_features = self.text_encoder(
            ids=batch['id'],
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

        self.image_embeddings = self.image_projection(image_features)
        self.text_embeddings = self.text_projection(text_features)
        self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)

        score = self.classifier(self.multimodal_embeddings)
        # Predicted class indices; detached because no gradient is needed here.
        _, output = torch.max(score.detach(), dim=1)

        self.similarity = self.text_embeddings @ self.image_embeddings.T

        return output, score


def calculate_loss(model, score, label):
    """Combine the classification loss and the image/text similarity loss.

    Args:
        model: A ``FakeNewsModel`` whose ``image_embeddings`` and
            ``text_embeddings`` were populated by a forward pass.
        score: Classifier output returned by ``model.forward``.
        label: Target class labels passed to the model's CrossEntropyLoss.

    Returns:
        Tuple ``(loss, c_loss, s_loss)`` with
        ``loss = model.config.loss_weight * c_loss + s_loss``.
    """
    s_loss = calculate_similarity_loss(model.image_embeddings, model.text_embeddings)
    # NOTE(review): `score` has already been through Softmax in Classifier,
    # while nn.CrossEntropyLoss expects unnormalised logits and applies
    # log-softmax itself, so softmax is effectively applied twice. Left as-is
    # because changing it alters the training objective -- confirm intent.
    c_loss = model.classifier_loss_function(score, label)
    loss = model.config.loss_weight * c_loss + s_loss
    return loss, c_loss, s_loss


def calculate_similarity_loss(image_embeddings, text_embeddings):
    """Return the symmetric soft-target cross-entropy between the modalities.

    The logits are the text-to-image dot products. The soft targets are the
    softmax of the average of the image-image and text-text similarity
    matrices. The loss is the mean of the text-side and image-side
    cross-entropies.

    Args:
        image_embeddings: Projected image embeddings; assumed 2-D
            (batch, dim) -- TODO confirm.
        text_embeddings: Projected text embeddings with the same shape.

    Returns:
        Scalar tensor with the averaged loss.
    """
    logits = (text_embeddings @ image_embeddings.T)
    images_similarity = (image_embeddings @ image_embeddings.T)
    texts_similarity = (text_embeddings @ text_embeddings.T)
    targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)

    texts_loss = cross_entropy(logits, targets, reduction='mean')
    images_loss = cross_entropy(logits.T, targets.T, reduction='mean')
    loss = (images_loss + texts_loss) / 2.0
    return loss


def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits ``preds`` and soft targets ``targets``.

    Args:
        preds: Unnormalised scores; log-softmax is taken over the last dim.
        targets: Target weights with the same shape as ``preds``.
        reduction: ``"none"`` returns the per-row loss, ``"mean"`` its mean,
            ``"sum"`` its sum.

    Returns:
        The per-row loss tensor, or a scalar tensor for ``"mean"``/``"sum"``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Computed by hand because `targets` is a full distribution per row
    # rather than a class index.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(
        f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'."
    )