from torch import nn
from transformers import (
    BertConfig,
    BertModel,
    RobertaConfig,
    RobertaModel,
    XLNetConfig,
    XLNetModel,
)


class TextEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.model = get_text_model(config)
        # Freeze or fine-tune the backbone according to the config flag.
        for p in self.model.parameters():
            p.requires_grad = self.config.trainable
        # The hidden state at this position is used as the sequence embedding
        # (position 0 holds the [CLS]/<s> token for BERT- and RoBERTa-style models).
        self.target_token_idx = 0
        # Optional cache for per-sample embeddings (see forward()).
        self.text_encoder_embedding = dict()

    def forward(self, ids, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :]

        # Uncomment to cache each sample's embedding, keyed by its dataset id;
        # `ids` is unused otherwise. The loop variable is named `sample_id`
        # to avoid shadowing the builtin `id`.
        # for i, sample_id in enumerate(ids):
        #     sample_id = int(sample_id.detach().cpu().numpy())
        #     self.text_encoder_embedding[sample_id] = last_hidden_state[i].detach().cpu().numpy()
        return last_hidden_state


def get_text_model(config):
    """Instantiate the backbone named by config.text_encoder_model; falls back
    to BERT when the name matches neither RoBERTa nor XLNet. When
    config.pretrained is False, a randomly initialised model with the
    library's default configuration is returned instead of the named checkpoint.
    """
    if 'roberta' in config.text_encoder_model:
        if config.pretrained:
            model = RobertaModel.from_pretrained(config.text_encoder_model, output_attentions=False,
                                                 output_hidden_states=True, return_dict=True)
        else:
            model = RobertaModel(config=RobertaConfig())
    elif 'xlnet' in config.text_encoder_model:
        if config.pretrained:
            model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False,
                                               output_hidden_states=True, return_dict=True)
        else:
            model = XLNetModel(config=XLNetConfig())
    else:
        if config.pretrained:
            model = BertModel.from_pretrained(config.text_encoder_model, output_attentions=False,
                                              output_hidden_states=True, return_dict=True)
        else:
            model = BertModel(config=BertConfig())
    return model
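

# A minimal usage sketch (not part of the original module): it builds a tiny
# config carrying the three attributes the code above reads (text_encoder_model,
# pretrained, trainable) and runs a dummy batch through the encoder. The
# checkpoint name and tensor shapes below are assumptions for illustration.
if __name__ == "__main__":
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        text_encoder_model="bert-base-uncased",  # any BERT/RoBERTa/XLNet checkpoint name
        pretrained=True,
        trainable=False,
    )
    encoder = TextEncoder(cfg)

    # Dummy batch: 2 sequences of 8 token ids with a full attention mask.
    ids = torch.arange(2)
    input_ids = torch.randint(0, 1000, (2, 8))
    attention_mask = torch.ones(2, 8, dtype=torch.long)

    embeddings = encoder(ids, input_ids, attention_mask)
    print(embeddings.shape)  # torch.Size([2, 768]) for bert-base-uncased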