Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

model.py 4.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. from image import ImageTransformerEncoder, ImageResnetEncoder
  5. from text import TextEncoder
  6. import numpy as np
  7. class ProjectionHead(nn.Module):
  8. def __init__(
  9. self,
  10. config,
  11. embedding_dim,
  12. ):
  13. super().__init__()
  14. self.config = config
  15. self.projection = nn.Linear(embedding_dim, config.projection_size)
  16. self.gelu = nn.GELU()
  17. self.fc = nn.Linear(config.projection_size, config.projection_size)
  18. self.dropout = nn.Dropout(config.dropout)
  19. self.layer_norm = nn.LayerNorm(config.projection_size)
  20. def forward(self, x):
  21. projected = self.projection(x)
  22. x = self.gelu(projected)
  23. x = self.fc(x)
  24. x = self.dropout(x)
  25. x = x + projected
  26. x = self.layer_norm(x)
  27. return x
  28. class Classifier(nn.Module):
  29. def __init__(self, config):
  30. super().__init__()
  31. self.layer_norm_1 = nn.LayerNorm(2 * config.projection_size)
  32. self.linear_layer = nn.Linear(2 * config.projection_size, config.hidden_size)
  33. self.gelu = nn.GELU()
  34. self.drop_out = nn.Dropout(config.dropout)
  35. self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
  36. self.classifier_layer = nn.Linear(config.hidden_size, config.class_num)
  37. self.softmax = nn.Softmax(dim=1)
  38. def forward(self, x):
  39. x = self.layer_norm_1(x)
  40. x = self.linear_layer(x)
  41. x = self.gelu(x)
  42. x = self.drop_out(x)
  43. self.embeddings = x = self.layer_norm_2(x)
  44. x = self.classifier_layer(x)
  45. x = self.softmax(x)
  46. return x
  47. class FakeNewsModel(nn.Module):
  48. def __init__(
  49. self, config
  50. ):
  51. super().__init__()
  52. self.config = config
  53. if 'resnet' in self.config.image_model_name:
  54. self.image_encoder = ImageResnetEncoder(config)
  55. else:
  56. self.image_encoder = ImageTransformerEncoder(config)
  57. self.text_encoder = TextEncoder(config)
  58. self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
  59. self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
  60. self.classifier = Classifier(config)
  61. class_weights = torch.FloatTensor(config.class_weights)
  62. self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
  63. self.text_embeddings = None
  64. self.image_embeddings = None
  65. self.multimodal_embeddings = None
  66. def forward(self, batch):
  67. image_features = self.image_encoder(ids=batch['id'], image=batch["image"])
  68. text_features = self.text_encoder(ids=batch['id'],
  69. input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
  70. self.image_embeddings = self.image_projection(image_features)
  71. self.text_embeddings = self.text_projection(text_features)
  72. self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
  73. score = self.classifier(self.multimodal_embeddings)
  74. probs, output = torch.max(score.data, dim=1)
  75. self.similarity = self.text_embeddings @ self.image_embeddings.T
  76. return output, score
  77. def calculate_loss(model, score, label):
  78. s_loss = calculate_similarity_loss(model.image_embeddings, model.text_embeddings)
  79. c_loss = model.classifier_loss_function(score, label)
  80. loss = model.config.loss_weight * c_loss + s_loss
  81. return loss, c_loss, s_loss
  82. def calculate_similarity_loss(image_embeddings, text_embeddings):
  83. logits = (text_embeddings @ image_embeddings.T)
  84. images_similarity = (image_embeddings @ image_embeddings.T)
  85. texts_similarity = (text_embeddings @ text_embeddings.T)
  86. targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
  87. texts_loss = cross_entropy(logits, targets, reduction='mean')
  88. images_loss = cross_entropy(logits.T, targets.T, reduction='mean')
  89. loss = (images_loss + texts_loss) / 2.0
  90. return loss
  91. def cross_entropy(preds, targets, reduction='none'):
  92. # entropy = torch.nn.CrossEntropyLoss(reduction=reduction)
  93. # return entropy(preds, targets)
  94. log_softmax = nn.LogSoftmax(dim=-1)
  95. loss = (-targets * log_softmax(preds)).sum(1)
  96. if reduction == "none":
  97. return loss
  98. elif reduction == "mean":
  99. return loss.mean()