from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
import albumentations as A | import albumentations as A | ||||
from transformers import ViTFeatureExtractor | from transformers import ViTFeatureExtractor | ||||
from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||||
from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer | |||||
def get_transforms(config): | def get_transforms(config): | ||||
def get_tokenizer(config): | def get_tokenizer(config): | ||||
if 'bigbird' in config.text_encoder_model: | |||||
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||||
if 'roberta' in config.text_encoder_model: | |||||
tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer) | |||||
elif 'xlnet' in config.text_encoder_model: | elif 'xlnet' in config.text_encoder_model: | ||||
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | ||||
else: | else: |
image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224' | image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224' | ||||
image_embedding = 768 | image_embedding = 768 | ||||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased" | |||||
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base" | |||||
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | ||||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased" | |||||
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base" | |||||
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | ||||
text_embedding = 768 | text_embedding = 768 | ||||
max_length = 32 | max_length = 32 |
from torch import nn | from torch import nn | ||||
from transformers import BertModel, BertConfig, BertTokenizer, \ | |||||
BigBirdModel, BigBirdConfig, BigBirdTokenizer, \ | |||||
XLNetModel, XLNetConfig, XLNetTokenizer | |||||
from transformers import BertModel, BertConfig, \ | |||||
RobertaModel, RobertaConfig, \ | |||||
XLNetModel, XLNetConfig | |||||
class TextEncoder(nn.Module): | class TextEncoder(nn.Module): | ||||
def get_text_model(config): | def get_text_model(config): | ||||
if 'bigbird' in config.text_encoder_model: | |||||
if 'roberta' in config.text_encoder_model: | |||||
if config.pretrained: | if config.pretrained: | ||||
model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2) | |||||
model = RobertaModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||||
output_hidden_states=True, return_dict=True) | |||||
else: | else: | ||||
model = BigBirdModel(config=BigBirdConfig()) | |||||
model = RobertaModel(config=RobertaConfig()) | |||||
elif 'xlnet' in config.text_encoder_model: | elif 'xlnet' in config.text_encoder_model: | ||||
if config.pretrained: | if config.pretrained: | ||||
model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | ||||
output_hidden_states=True, return_dict=True) | |||||
output_hidden_states=True, return_dict=True) | |||||
else: | else: | ||||
model = XLNetModel(config=XLNetConfig()) | model = XLNetModel(config=XLNetConfig()) | ||||
else: | else: | ||||
else: | else: | ||||
model = BertModel(config=BertConfig()) | model = BertModel(config=BertConfig()) | ||||
return model | return model | ||||