@@ -2,7 +2,7 @@ import torch | |||
from torch.utils.data import Dataset | |||
import albumentations as A | |||
from transformers import ViTFeatureExtractor | |||
from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||
from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer | |||
def get_transforms(config): | |||
@@ -15,8 +15,8 @@ def get_transforms(config): | |||
def get_tokenizer(config): | |||
if 'bigbird' in config.text_encoder_model: | |||
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||
if 'roberta' in config.text_encoder_model: | |||
tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer) | |||
elif 'xlnet' in config.text_encoder_model: | |||
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | |||
else: |
@@ -42,9 +42,9 @@ class TwitterConfig(Config): | |||
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224' | |||
image_embedding = 768 | |||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased" | |||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base" | |||
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased" | |||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base" | |||
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | |||
text_embedding = 768 | |||
max_length = 32 |
@@ -1,7 +1,7 @@ | |||
from torch import nn | |||
from transformers import BertModel, BertConfig, BertTokenizer, \ | |||
BigBirdModel, BigBirdConfig, BigBirdTokenizer, \ | |||
XLNetModel, XLNetConfig, XLNetTokenizer | |||
from transformers import BertModel, BertConfig, \ | |||
RobertaModel, RobertaConfig, \ | |||
XLNetModel, XLNetConfig | |||
class TextEncoder(nn.Module): | |||
@@ -26,15 +26,16 @@ class TextEncoder(nn.Module): | |||
def get_text_model(config): | |||
if 'bigbird' in config.text_encoder_model: | |||
if 'roberta' in config.text_encoder_model: | |||
if config.pretrained: | |||
model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2) | |||
model = RobertaModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||
output_hidden_states=True, return_dict=True) | |||
else: | |||
model = BigBirdModel(config=BigBirdConfig()) | |||
model = RobertaModel(config=RobertaConfig()) | |||
elif 'xlnet' in config.text_encoder_model: | |||
if config.pretrained: | |||
model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||
output_hidden_states=True, return_dict=True) | |||
output_hidden_states=True, return_dict=True) | |||
else: | |||
model = XLNetModel(config=XLNetConfig()) | |||
else: | |||
@@ -44,4 +45,3 @@ def get_text_model(config): | |||
else: | |||
model = BertModel(config=BertConfig()) | |||
return model | |||