Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

model.py 4.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. from image import ImageTransformerEncoder, ImageResnetEncoder
  5. from text import TextEncoder
  6. class ProjectionHead(nn.Module):
  7. def __init__(
  8. self,
  9. config,
  10. embedding_dim,
  11. ):
  12. super().__init__()
  13. self.config = config
  14. self.projection = nn.Linear(embedding_dim, config.projection_size)
  15. self.gelu = nn.GELU()
  16. self.fc = nn.Linear(config.projection_size, config.projection_size)
  17. self.dropout = nn.Dropout(config.dropout)
  18. self.layer_norm = nn.LayerNorm(config.projection_size)
  19. def forward(self, x):
  20. projected = self.projection(x)
  21. x = self.gelu(projected)
  22. x = self.fc(x)
  23. x = self.dropout(x)
  24. x = x + projected
  25. x = self.layer_norm(x)
  26. return x
  27. class Classifier(nn.Module):
  28. def __init__(self, config):
  29. super().__init__()
  30. self.layer_norm_1 = nn.LayerNorm(2 * config.projection_size)
  31. self.linear_layer = nn.Linear(2 * config.projection_size, config.hidden_size)
  32. self.gelu = nn.GELU()
  33. self.drop_out = nn.Dropout(config.dropout)
  34. self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
  35. self.classifier_layer = nn.Linear(config.hidden_size, config.class_num)
  36. self.softmax = nn.Softmax(dim=1)
  37. def forward(self, x):
  38. x = self.layer_norm_1(x)
  39. x = self.linear_layer(x)
  40. x = self.gelu(x)
  41. x = self.drop_out(x)
  42. self.embeddings = x = self.layer_norm_2(x)
  43. x = self.classifier_layer(x)
  44. x = self.softmax(x)
  45. return x
  46. class FakeNewsModel(nn.Module):
  47. def __init__(
  48. self, config
  49. ):
  50. super().__init__()
  51. self.config = config
  52. if 'resnet' in self.config.image_model_name:
  53. self.image_encoder = ImageResnetEncoder(config)
  54. else:
  55. self.image_encoder = ImageTransformerEncoder(config)
  56. self.text_encoder = TextEncoder(config)
  57. self.image_projection = ProjectionHead(config, embedding_dim=config.image_embedding)
  58. self.text_projection = ProjectionHead(config, embedding_dim=config.text_embedding)
  59. self.classifier = Classifier(config)
  60. class_weights = torch.FloatTensor(config.class_weights)
  61. self.classifier_loss_function = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
  62. self.text_embeddings = None
  63. self.image_embeddings = None
  64. self.multimodal_embeddings = None
  65. def forward(self, batch):
  66. image_features = self.image_encoder(ids=batch['id'], image=batch["image"])
  67. text_features = self.text_encoder(ids=batch['id'],
  68. input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
  69. self.image_embeddings = self.image_projection(image_features)
  70. self.text_embeddings = self.text_projection(text_features)
  71. self.multimodal_embeddings = torch.cat((self.image_embeddings, self.text_embeddings), dim=1)
  72. self.logits = (self.text_embeddings @ self.image_embeddings.T)
  73. score = self.classifier(self.multimodal_embeddings)
  74. probs, output = torch.max(score.data, dim=1)
  75. return output, score
  76. def calculate_loss(model, score, label):
  77. similarity = calculate_similarity_loss(model.config, model.image_embeddings, model.text_embeddings)
  78. # fake = (label == 1).nonzero()
  79. # real = (label == 0).nonzero()
  80. # s_loss = 0 * similarity[fake].mean() + similarity[real].mean()
  81. c_loss = model.classifier_loss_function(score, label)
  82. loss = model.config.loss_weight * c_loss + similarity.mean()
  83. return loss
  84. def calculate_similarity_loss(config, image_embeddings, text_embeddings):
  85. # Calculating the Loss
  86. logits = (text_embeddings @ image_embeddings.T)
  87. images_similarity = image_embeddings @ image_embeddings.T
  88. texts_similarity = text_embeddings @ text_embeddings.T
  89. targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
  90. texts_loss = cross_entropy(logits, targets, reduction='none')
  91. images_loss = cross_entropy(logits.T, targets.T, reduction='none')
  92. loss = (images_loss + texts_loss) / 2.0
  93. return loss
  94. def cross_entropy(preds, targets, reduction='none'):
  95. log_softmax = nn.LogSoftmax(dim=-1)
  96. loss = (-targets * log_softmax(preds)).sum(1)
  97. if reduction == "none":
  98. return loss
  99. elif reduction == "mean":
  100. return loss.mean()