Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

text.py 1.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from torch import nn
  2. from transformers import BertModel, BertConfig, BertTokenizer, \
  3. BigBirdModel, BigBirdConfig, BigBirdTokenizer, \
  4. XLNetModel, XLNetConfig, XLNetTokenizer
  class TextEncoder(nn.Module):
      """Text branch: wraps a transformer backbone and returns one token's hidden state.

      The backbone is built by ``get_text_model(config)``. Every backbone
      parameter has ``requires_grad`` set to ``config.trainable``.
      """

      def __init__(self, config):
          """Build the backbone and apply the trainable/frozen setting.

          Args:
              config: settings object; this class reads ``config.trainable``
                  and passes the whole object to ``get_text_model``.
          """
          super().__init__()
          self.config = config

          # Position along the sequence axis whose hidden state is returned.
          # NOTE(review): index 0 is the usual [CLS] slot for BERT-style
          # tokenizers; XLNet tokenizers typically put the classification
          # token at the end -- confirm this index is intended for every
          # backbone that get_text_model can return.
          self.target_token_idx = 0

          # Starts empty. Only a disabled caching step in forward() ever
          # referred to it (it would have mapped sample id -> embedding).
          self.text_encoder_embedding = dict()

          self.model = get_text_model(config)
          for param in self.model.parameters():
              param.requires_grad = self.config.trainable

      def forward(self, ids, input_ids, attention_mask):
          """Run the backbone and return the hidden state at ``target_token_idx``.

          Args:
              ids: sample identifiers. Currently unused; they were only
                  consumed by a disabled step that cached each sample's
                  embedding in ``self.text_encoder_embedding``.
              input_ids: forwarded to the backbone as ``input_ids``.
              attention_mask: forwarded to the backbone as ``attention_mask``.

          Returns:
              ``last_hidden_state[:, target_token_idx, :]`` of the backbone
              output (presumably one vector per batch element -- the
              indexing implies a (batch, seq, hidden) layout).
          """
          backbone_out = self.model(input_ids=input_ids, attention_mask=attention_mask)
          return backbone_out.last_hidden_state[:, self.target_token_idx, :]
  def get_text_model(config):
      """Instantiate the text backbone selected by ``config.text_encoder_model``.

      The family is chosen by substring match on the model name: ``'bigbird'``
      gives a BigBird model, ``'xlnet'`` gives an XLNet model, and anything
      else falls back to BERT.

      Args:
          config: settings object providing ``text_encoder_model`` (model
              name or path) and ``pretrained`` (truthy to load via
              ``from_pretrained``).

      Returns:
          The constructed model. When ``config.pretrained`` is falsy, the
          model is built from the family's default ``*Config()`` and the name
          in ``config.text_encoder_model`` is used only to pick the family.
      """
      model_name = config.text_encoder_model

      # Keyword arguments shared by the XLNet and BERT pretrained loaders.
      hf_kwargs = dict(output_attentions=False, output_hidden_states=True, return_dict=True)

      if 'bigbird' in model_name:
          if not config.pretrained:
              return BigBirdModel(config=BigBirdConfig())
          return BigBirdModel.from_pretrained(model_name, block_size=16, num_random_blocks=2)

      if 'xlnet' in model_name:
          if not config.pretrained:
              return XLNetModel(config=XLNetConfig())
          return XLNetModel.from_pretrained(model_name, **hf_kwargs)

      # Fallback: any other model name is treated as a BERT checkpoint.
      if not config.pretrained:
          return BertModel(config=BertConfig())
      return BertModel.from_pretrained(model_name, **hf_kwargs)