| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
| import albumentations as A | import albumentations as A | ||||
| from transformers import ViTFeatureExtractor | from transformers import ViTFeatureExtractor | ||||
| from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer | |||||
| from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer | |||||
| def get_transforms(config): | def get_transforms(config): | ||||
| def get_tokenizer(config): | def get_tokenizer(config): | ||||
| if 'bigbird' in config.text_encoder_model: | |||||
| tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer) | |||||
| if 'roberta' in config.text_encoder_model: | |||||
| tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer) | |||||
| elif 'xlnet' in config.text_encoder_model: | elif 'xlnet' in config.text_encoder_model: | ||||
| tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) | ||||
| else: | else: |
| image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224' | image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224' | ||||
| image_embedding = 768 | image_embedding = 768 | ||||
| text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased" | |||||
| text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base" | |||||
| # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | ||||
| text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased" | |||||
| text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base" | |||||
| # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" | ||||
| text_embedding = 768 | text_embedding = 768 | ||||
| max_length = 32 | max_length = 32 |
| from torch import nn | from torch import nn | ||||
| from transformers import BertModel, BertConfig, BertTokenizer, \ | |||||
| BigBirdModel, BigBirdConfig, BigBirdTokenizer, \ | |||||
| XLNetModel, XLNetConfig, XLNetTokenizer | |||||
| from transformers import BertModel, BertConfig, \ | |||||
| RobertaModel, RobertaConfig, \ | |||||
| XLNetModel, XLNetConfig | |||||
| class TextEncoder(nn.Module): | class TextEncoder(nn.Module): | ||||
| def get_text_model(config): | def get_text_model(config): | ||||
| if 'bigbird' in config.text_encoder_model: | |||||
| if 'roberta' in config.text_encoder_model: | |||||
| if config.pretrained: | if config.pretrained: | ||||
| model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2) | |||||
| model = RobertaModel.from_pretrained(config.text_encoder_model, output_attentions=False, | |||||
| output_hidden_states=True, return_dict=True) | |||||
| else: | else: | ||||
| model = BigBirdModel(config=BigBirdConfig()) | |||||
| model = RobertaModel(config=RobertaConfig()) | |||||
| elif 'xlnet' in config.text_encoder_model: | elif 'xlnet' in config.text_encoder_model: | ||||
| if config.pretrained: | if config.pretrained: | ||||
| model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, | ||||
| output_hidden_states=True, return_dict=True) | |||||
| output_hidden_states=True, return_dict=True) | |||||
| else: | else: | ||||
| model = XLNetModel(config=XLNetConfig()) | model = XLNetModel(config=XLNetConfig()) | ||||
| else: | else: | ||||
| else: | else: | ||||
| model = BertModel(config=BertConfig()) | model = BertModel(config=BertConfig()) | ||||
| return model | return model | ||||