from torch import nn
from transformers import (
    BertConfig,
    BertModel,
    RobertaConfig,
    RobertaModel,
    XLNetConfig,
    XLNetModel,
)


class TextEncoder(nn.Module):
    """Wraps a Hugging Face transformer and returns one embedding per input text."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = get_text_model(config)

        # Freeze or unfreeze the entire backbone according to the config flag.
        for p in self.model.parameters():
            p.requires_grad = self.config.trainable

        # Index of the token whose hidden state serves as the sentence
        # embedding: position 0 is [CLS] for BERT and <s> for RoBERTa.
        # (Caveat: XLNet places its <cls> token at the *end* of the
        # sequence, so index 0 is not the summary token for that backbone.)
        self.target_token_idx = 0

        # Optional cache of per-sample embeddings, filled by the disabled
        # block in forward().
        self.text_encoder_embedding = {}

    def forward(self, ids, input_ids, attention_mask):
        # `ids` identifies the samples in the batch and is only consumed by
        # the disabled caching block below.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state[:, self.target_token_idx, :]

        # Uncomment to cache each sample's embedding under its id:
        # for i, id in enumerate(ids):
        #     id = int(id.detach().cpu().numpy())
        #     self.text_encoder_embedding[id] = last_hidden_state[i].detach().cpu().numpy()

        return last_hidden_state


def get_text_model(config):
    """Instantiate the backbone named by config.text_encoder_model.

    Dispatches on a substring of the model name; anything that is neither
    RoBERTa nor XLNet falls through to BERT. When config.pretrained is
    False, a randomly initialized model with default hyperparameters is
    built instead of loading pretrained weights.
    """
    if 'roberta' in config.text_encoder_model:
        if config.pretrained:
            model = RobertaModel.from_pretrained(config.text_encoder_model,
                                                 output_attentions=False,
                                                 output_hidden_states=True,
                                                 return_dict=True)
        else:
            model = RobertaModel(config=RobertaConfig())
    elif 'xlnet' in config.text_encoder_model:
        if config.pretrained:
            model = XLNetModel.from_pretrained(config.text_encoder_model,
                                               output_attentions=False,
                                               output_hidden_states=True,
                                               return_dict=True)
        else:
            model = XLNetModel(config=XLNetConfig())
    else:
        if config.pretrained:
            model = BertModel.from_pretrained(config.text_encoder_model,
                                              output_attentions=False,
                                              output_hidden_states=True,
                                              return_dict=True)
        else:
            model = BertModel(config=BertConfig())
    return model
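

# Usage sketch, not part of the original module. Assumptions: the config
# object only needs the three attributes read above (`text_encoder_model`,
# `pretrained`, `trainable`); the model name "roberta-base" is a
# hypothetical choice for the demo; the matching tokenizer is loaded via
# transformers.AutoTokenizer, which this module itself does not use.
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch
    from transformers import AutoTokenizer

    cfg = SimpleNamespace(
        text_encoder_model="roberta-base",  # hypothetical backbone for the demo
        pretrained=True,
        trainable=False,  # freeze the backbone
    )

    encoder = TextEncoder(cfg)
    tokenizer = AutoTokenizer.from_pretrained(cfg.text_encoder_model)

    batch = tokenizer(
        ["a photo of a dog", "a photo of a cat"],
        padding=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        emb = encoder(
            ids=torch.arange(2),  # placeholder; only read by the disabled caching block
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    # One embedding per text: (batch_size, hidden_size),
    # e.g. torch.Size([2, 768]) for roberta-base.
    print(emb.shape)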