from transformers import BertTokenizerFast, DataCollatorWithPadding
-
-
class TokenizerMan:
    """Build a pretrained tokenizer and a matching padding collate function.

    Parameters
    ----------
    tokenizer_kind : str
        Family of tokenizer to construct. Only ``'bert'`` is supported.
    pretrained_name : str
        Model name or path forwarded to ``from_pretrained``.

    Raises
    ------
    ValueError
        If ``tokenizer_kind`` is not a supported value.
    """

    def __init__(self, tokenizer_kind: str, pretrained_name: str):
        if tokenizer_kind == 'bert':
            self.tokenizer = BertTokenizerFast.from_pretrained(pretrained_name)
        else:
            # ValueError is more precise than a bare Exception and names the
            # offending input; still backward-compatible for callers that
            # catch Exception.
            raise ValueError(f'Unsupported tokenizer kind: {tokenizer_kind!r}')

    def get_col_fn(self):
        """Return a collate_fn that pads each batch to its longest sequence.

        The collator emits PyTorch tensors (``return_tensors='pt'``) and pads
        dynamically per batch (``padding='longest'``), which avoids padding
        every example to a global maximum length.
        """
        return DataCollatorWithPadding(
            self.tokenizer, return_tensors='pt', padding='longest'
        )