Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

text.py 1.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from torch import nn
  2. from transformers import BertModel, BertConfig, \
  3. RobertaModel, RobertaConfig, \
  4. XLNetModel, XLNetConfig
  5. class TextEncoder(nn.Module):
  6. def __init__(self, config):
  7. super().__init__()
  8. self.config = config
  9. self.model = get_text_model(config)
  10. for p in self.model.parameters():
  11. p.requires_grad = self.config.trainable
  12. self.target_token_idx = 0
  13. self.text_encoder_embedding = dict()
  14. def forward(self, ids, input_ids, attention_mask):
  15. output = self.model(input_ids=input_ids, attention_mask=attention_mask)
  16. last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :]
  17. # for i, id in enumerate(ids):
  18. # id = int(id.detach().cpu().numpy())
  19. # self.text_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy()
  20. return last_hidden_state
  21. def get_text_model(config):
  22. if 'roberta' in config.text_encoder_model:
  23. if config.pretrained:
  24. model = RobertaModel.from_pretrained(config.text_encoder_model, output_attentions=False,
  25. output_hidden_states=True, return_dict=True)
  26. else:
  27. model = RobertaModel(config=RobertaConfig())
  28. elif 'xlnet' in config.text_encoder_model:
  29. if config.pretrained:
  30. model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False,
  31. output_hidden_states=True, return_dict=True)
  32. else:
  33. model = XLNetModel(config=XLNetConfig())
  34. else:
  35. if config.pretrained:
  36. model = BertModel.from_pretrained(config.text_encoder_model, output_attentions=False,
  37. output_hidden_states=True, return_dict=True)
  38. else:
  39. model = BertModel(config=BertConfig())
  40. return model