Browse Source

roberta added

master
Faeze 1 year ago
parent
commit
9092e12d67
3 changed files with 13 additions and 13 deletions
  1. 3
    3
      data/data_loader.py
  2. 2
    2
      data/twitter/config.py
  3. 8
    8
      text.py

+ 3
- 3
data/data_loader.py View File

@@ -2,7 +2,7 @@ import torch
from torch.utils.data import Dataset
import albumentations as A
from transformers import ViTFeatureExtractor
from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer
from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer


def get_transforms(config):
@@ -15,8 +15,8 @@ def get_transforms(config):


def get_tokenizer(config):
if 'bigbird' in config.text_encoder_model:
tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
if 'roberta' in config.text_encoder_model:
tokenizer = RobertaTokenizer.from_pretrained(config.text_tokenizer)
elif 'xlnet' in config.text_encoder_model:
tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
else:

+ 2
- 2
data/twitter/config.py View File

@@ -42,9 +42,9 @@ class TwitterConfig(Config):

image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/vit-base-patch16-224'
image_embedding = 768
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased"
text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base"
# text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/bert-base-uncased"
text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/huggingface/roberta-base"
# text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
text_embedding = 768
max_length = 32

+ 8
- 8
text.py View File

@@ -1,7 +1,7 @@
from torch import nn
from transformers import BertModel, BertConfig, BertTokenizer, \
BigBirdModel, BigBirdConfig, BigBirdTokenizer, \
XLNetModel, XLNetConfig, XLNetTokenizer
from transformers import BertModel, BertConfig, \
RobertaModel, RobertaConfig, \
XLNetModel, XLNetConfig


class TextEncoder(nn.Module):
@@ -26,15 +26,16 @@ class TextEncoder(nn.Module):


def get_text_model(config):
if 'bigbird' in config.text_encoder_model:
if 'roberta' in config.text_encoder_model:
if config.pretrained:
model = BigBirdModel.from_pretrained(config.text_encoder_model, block_size=16, num_random_blocks=2)
model = RobertaModel.from_pretrained(config.text_encoder_model, output_attentions=False,
output_hidden_states=True, return_dict=True)
else:
model = BigBirdModel(config=BigBirdConfig())
model = RobertaModel(config=RobertaConfig())
elif 'xlnet' in config.text_encoder_model:
if config.pretrained:
model = XLNetModel.from_pretrained(config.text_encoder_model, output_attentions=False,
output_hidden_states=True, return_dict=True)
output_hidden_states=True, return_dict=True)
else:
model = XLNetModel(config=XLNetConfig())
else:
@@ -44,4 +45,3 @@ def get_text_model(config):
else:
model = BertModel(config=BertConfig())
return model


Loading…
Cancel
Save