Browse Source

roberta added

master
Faeze 1 year ago
parent
commit
9092e12d67
3 changed files with 13 additions and 13 deletions
  1. 3
    3
      data/data_loader.py
  2. 2
    2
      data/twitter/config.py
  3. 8
    8
      text.py

+ 3
- 3
data/data_loader.py View File

from torch.utils.data import Dataset from torch.utils.data import Dataset
import albumentations as A import albumentations as A
from transformers import ViTFeatureExtractor from transformers import ViTFeatureExtractor
from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer
from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer




def get_transforms(config): def get_transforms(config):




def get_tokenizer(config): def get_tokenizer(config):
if 'bigbird' in config.text_encoder_model:
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
if 'roberta' in config.text_encoder_model:
tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer)
elif 'xlnet' in config.text_encoder_model: elif 'xlnet' in config.text_encoder_model:
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer) tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
else: else:

+ 2
- 2
data/twitter/config.py View File



image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224' image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224'
image_embedding = 768 image_embedding = 768
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased"
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base"
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base"
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased" # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_embedding = 768 text_embedding = 768
max_length = 32 max_length = 32

+ 8
- 8
text.py View File

from torch import nn from torch import nn
from transformers import BertModel, BertConfig, BertTokenizer, \
BigBirdModel, BigBirdConfig, BigBirdTokenizer, \
XLNetModel, XLNetConfig, XLNetTokenizer
from transformers import BertModel, BertConfig, \
RobertaModel, RobertaConfig, \
XLNetModel, XLNetConfig




class TextEncoder(nn.Module): class TextEncoder(nn.Module):




def get_text_model(config): def get_text_model(config):
if 'bigbird' in config.text_encoder_model:
if 'roberta' in config.text_encoder_model:
if config.pretrained: if config.pretrained:
model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2)
model = RobertaModel.from_pretrained(config.text_encoder_model, output_attentions=False,
output_hidden_states=True, return_dict=True)
else: else:
model = BigBirdModel(config=BigBirdConfig())
model = RobertaModel(config=RobertaConfig())
elif 'xlnet' in config.text_encoder_model: elif 'xlnet' in config.text_encoder_model:
if config.pretrained: if config.pretrained:
model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False, model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False,
output_hidden_states=True, return_dict=True)
output_hidden_states=True, return_dict=True)
else: else:
model = XLNetModel(config=XLNetConfig()) model = XLNetModel(config=XLNetConfig())
else: else:
else: else:
model = BertModel(config=BertConfig()) model = BertModel(config=BertConfig())
return model return model


Loading…
Cancel
Save