Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

image.py 1.4KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import timm
  2. from torch import nn
  3. from transformers import ViTModel, ViTConfig
  class ImageTransformerEncoder(nn.Module):
      """Image encoder backed by a Hugging Face ``ViTModel``.

      The encoder returns, for every image in the batch, the hidden vector at
      position ``target_token_idx`` of the model's last hidden state (index 0,
      which is the [CLS] token for Hugging Face ViT models).
      """

      def __init__(self, config):
          """Build the ViT backbone.

          Args:
              config: Object exposing ``pretrained`` (load pretrained weights
                  when truthy), ``image_model_name`` (checkpoint name passed to
                  ``ViTModel.from_pretrained``) and ``trainable`` (value
                  assigned to ``requires_grad`` on every backbone parameter).
          """
          super().__init__()
          if config.pretrained:
              self.model = ViTModel.from_pretrained(
                  config.image_model_name,
                  output_attentions=False,
                  output_hidden_states=True,
                  return_dict=True,
              )
          else:
              # Randomly initialised ViT with the library's default config;
              # ``image_model_name`` is not consulted on this path.
              self.model = ViTModel(config=ViTConfig())

          # Freeze or unfreeze the backbone, whichever branch created it.
          for p in self.model.parameters():
              p.requires_grad = config.trainable

          # Position in the token sequence whose hidden state is returned.
          self.target_token_idx = 0
          # Not populated by this class at present (see note in ``forward``).
          self.image_encoder_embedding = {}

      def forward(self, ids, image):
          """Return the target-token embedding for each image.

          Args:
              ids: Sample identifiers. Currently unused; an earlier, disabled
                  version cached each sample's embedding in
                  ``image_encoder_embedding`` keyed by id.
              image: Batch of images passed straight to the ViT model.

          Returns:
              ``last_hidden_state[:, target_token_idx, :]`` -- one vector per
              batch element.
          """
          output = self.model(image)
          return output.last_hidden_state[:, self.target_token_idx, :]
  class ImageResnetEncoder(nn.Module):
      """Image encoder backed by a ``timm`` model.

      The model is created with ``num_classes=0`` and ``global_pool="avg"``,
      which in timm removes the classification head so the forward pass yields
      pooled feature vectors rather than class logits.
      """

      def __init__(self, config):
          """Build the timm backbone.

          Args:
              config: Object exposing ``image_model_name`` (timm model name),
                  ``pretrained`` (whether to load pretrained weights) and
                  ``trainable`` (value assigned to ``requires_grad`` on every
                  backbone parameter).
          """
          super().__init__()
          self.model = timm.create_model(
              config.image_model_name,
              pretrained=config.pretrained,
              num_classes=0,
              global_pool="avg",
          )
          # Freeze or unfreeze the whole backbone according to the config.
          for p in self.model.parameters():
              p.requires_grad = config.trainable

      def forward(self, ids, image):
          """Return the backbone's output for ``image``.

          Args:
              ids: Sample identifiers. Unused; kept so the signature matches
                  ``ImageTransformerEncoder.forward``.
              image: Batch of images passed straight to the timm model.

          Returns:
              Whatever ``self.model(image)`` returns.
          """
          return self.model(image)