|
1234567891011121314151617181920212223242526272829303132 |
- from transformers import (
- T5TokenizerFast,
- BertTokenizerFast,
- BartTokenizerFast,
- T5ForConditionalGeneration,
- BertForSequenceClassification,
- BartForConditionalGeneration,
- BartForSequenceClassification
- )
-
-
def auto_model(model_name, output_info):
    """Load a pretrained model and its matching fast tokenizer by name.

    The model family is chosen by case-insensitive substring match on
    ``model_name``, checked in this order: ``'t5'``, ``'bart'``, ``'bert'``.

    Args:
        model_name: Identifier passed unchanged to ``from_pretrained``.
        output_info: Only read in the BERT branch, where
            ``len(output_info.names)`` is used as ``num_labels``.

    Returns:
        A ``(model, tokenizer)`` tuple. Both objects get an ``_is_seq2seq``
        attribute: ``True`` for T5/BART, ``False`` for BERT.

    Raises:
        NotImplementedError: If ``model_name`` matches none of the
            supported families.
    """
    lowered_name = model_name.lower()

    if 't5' in lowered_name:
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer = T5TokenizerFast.from_pretrained(model_name, model_max_length=2048)
        is_seq2seq = True
    elif 'bart' in lowered_name:
        model = BartForConditionalGeneration.from_pretrained(model_name)
        tokenizer = BartTokenizerFast.from_pretrained(model_name, model_max_length=1024)
        is_seq2seq = True
    elif 'bert' in lowered_name:
        class_count = len(output_info.names)
        model = BertForSequenceClassification.from_pretrained(model_name, num_labels=class_count)
        # The keyword was previously misspelled ("trunction").
        # NOTE(review): truncation is normally requested when the tokenizer is
        # called; confirm callers also pass truncation=True at encode time.
        tokenizer = BertTokenizerFast.from_pretrained(model_name, truncation=True)
        is_seq2seq = False
    else:
        raise NotImplementedError(
            f"Unsupported model {model_name!r}: expected a name containing "
            "'t5', 'bart' or 'bert'."
        )

    # Tag both objects so downstream code can branch on the architecture kind.
    model._is_seq2seq = is_seq2seq
    tokenizer._is_seq2seq = is_seq2seq
    return model, tokenizer
|