from transformers import (
    T5TokenizerFast,
    BertTokenizerFast,
    BartTokenizerFast,
    T5ForConditionalGeneration,
    BertForSequenceClassification,
    BartForConditionalGeneration,
    BartForSequenceClassification,
)


def auto_model(model_name, output_info):
    """Load a pretrained model and its fast tokenizer based on ``model_name``.

    The model family is picked by case-insensitive substring match on
    ``model_name``, checked in this order: ``'t5'``, ``'bart'``, ``'bert'``.
    T5 and BART are loaded as conditional-generation models; BERT is loaded
    as a sequence classifier.

    Both returned objects are tagged with a private ``_is_seq2seq`` boolean
    attribute (``True`` for T5/BART, ``False`` for BERT).

    Args:
        model_name: Name or path handed to ``from_pretrained``.
        output_info: Only read in the BERT branch, where
            ``len(output_info.names)`` is used as ``num_labels``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        NotImplementedError: If ``model_name`` contains none of the
            supported family markers.
    """
    # Lower-case once instead of once per branch.
    name = model_name.lower()

    if 't5' in name:
        model = T5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer = T5TokenizerFast.from_pretrained(
            model_name, model_max_length=2048
        )
        is_seq2seq = True
    elif 'bart' in name:
        model = BartForConditionalGeneration.from_pretrained(model_name)
        tokenizer = BartTokenizerFast.from_pretrained(
            model_name, model_max_length=1024
        )
        is_seq2seq = True
    elif 'bert' in name:
        class_count = len(output_info.names)
        model = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=class_count
        )
        # The keyword was previously misspelled as ``trunction`` and was
        # therefore silently ignored.
        # NOTE(review): truncation is normally requested when the tokenizer
        # is called; confirm that call sites also pass ``truncation=True``.
        tokenizer = BertTokenizerFast.from_pretrained(
            model_name, truncation=True
        )
        is_seq2seq = False
    else:
        raise NotImplementedError(
            f"Unsupported model name {model_name!r}: expected a name "
            "containing 't5', 'bart' or 'bert'."
        )

    # Tag both objects with the same flag, as the original did per branch.
    model._is_seq2seq = is_seq2seq
    tokenizer._is_seq2seq = is_seq2seq
    return model, tokenizer