You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

auto_model.py 1.2KB

3 months ago
1234567891011121314151617181920212223242526272829303132
  1. from transformers import (
  2. T5TokenizerFast,
  3. BertTokenizerFast,
  4. BartTokenizerFast,
  5. T5ForConditionalGeneration,
  6. BertForSequenceClassification,
  7. BartForConditionalGeneration,
  8. BartForSequenceClassification
  9. )
  10. def auto_model(model_name, output_info):
  11. if 't5' in model_name.lower():
  12. model = T5ForConditionalGeneration.from_pretrained(model_name)
  13. tokenizer = T5TokenizerFast.from_pretrained(model_name, model_max_length=2048)
  14. model._is_seq2seq = True
  15. tokenizer._is_seq2seq = True
  16. elif 'bart' in model_name.lower():
  17. model = BartForConditionalGeneration.from_pretrained(model_name)
  18. tokenizer = BartTokenizerFast.from_pretrained(model_name, model_max_length=1024)
  19. model._is_seq2seq = True
  20. tokenizer._is_seq2seq = True
  21. elif 'bert' in model_name.lower():
  22. class_count = len(output_info.names)
  23. model = BertForSequenceClassification.from_pretrained(model_name, num_labels=class_count)
  24. tokenizer = BertTokenizerFast.from_pretrained(model_name, trunction=True)
  25. model._is_seq2seq = False
  26. tokenizer._is_seq2seq = False
  27. else:
  28. raise NotImplementedError()
  29. return model, tokenizer