| @@ -2,7 +2,7 @@ import torch | |||
| from torch.utils.data import Dataset | |||
| import albumentations as A | |||
| from transformers import ViTFeatureExtractor | |||
| from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||
| from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer | |||
| def get_transforms(config): | |||
| @@ -15,8 +15,8 @@ def get_transforms(config): | |||
| def get_tokenizer(config): | |||
| if 'bigbird' in config.text_encoder_model: | |||
| tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||
| if 'roberta' in config.text_encoder_model: | |||
| tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer) | |||
| elif 'xlnet' in config.text_encoder_model: | |||
| tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||
| else: | |||
| @@ -42,9 +42,9 @@ class TwitterConfig(Config): | |||
| image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224' | |||
| image_embedding = 768 | |||
| text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased" | |||
| text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base" | |||
| # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
| text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased" | |||
| text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base" | |||
| # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
| text_embedding = 768 | |||
| max_length = 32 | |||
| @@ -1,7 +1,7 @@ | |||
| from torch import nn | |||
| from transformers import BertModel, BertConfig, BertTokenizer, \ | |||
| BigBirdModel, BigBirdConfig, BigBirdTokenizer, \ | |||
| XLNetModel, XLNetConfig, XLNetTokenizer | |||
| from transformers import BertModel, BertConfig, \ | |||
| RobertaModel, RobertaConfig, \ | |||
| XLNetModel, XLNetConfig | |||
| class TextEncoder(nn.Module): | |||
| @@ -26,15 +26,16 @@ class TextEncoder(nn.Module): | |||
| def get_text_model(config): | |||
| if 'bigbird' in config.text_encoder_model: | |||
| if 'roberta' in config.text_encoder_model: | |||
| if config.pretrained: | |||
| model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2) | |||
| model = RobertaModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||
| output_hidden_states=True, return_dict=True) | |||
| else: | |||
| model = BigBirdModel(config=BigBirdConfig()) | |||
| model = RobertaModel(config=RobertaConfig()) | |||
| elif 'xlnet' in config.text_encoder_model: | |||
| if config.pretrained: | |||
| model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||
| output_hidden_states=True, return_dict=True) | |||
| output_hidden_states=True, return_dict=True) | |||
| else: | |||
| model = XLNetModel(config=XLNetConfig()) | |||
| else: | |||
| @@ -44,4 +45,3 @@ def get_text_model(config): | |||
| else: | |||
| model = BertModel(config=BertConfig()) | |||
| return model | |||