```python
import torch
from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding


def generate_dataloader(tokenizer, ds_train, ds_valid_dict, train_bs, valid_bs):
    # `_is_seq2seq` is a custom flag set on the tokenizer elsewhere in this
    # codebase; encoder-decoder models need a collator that also pads labels.
    if tokenizer._is_seq2seq:
        col_fn = DataCollatorForSeq2Seq(
            tokenizer, return_tensors='pt', padding='longest'
        )
    else:
        col_fn = DataCollatorWithPadding(
            tokenizer, return_tensors='pt', padding='longest'
        )

    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=train_bs,
        collate_fn=col_fn,
        shuffle=True
    )

    # One loader per validation split; no shuffling needed for evaluation.
    valid_loader = {
        key: torch.utils.data.DataLoader(
            val,
            batch_size=valid_bs,
            collate_fn=col_fn,
        )
        for key, val in ds_valid_dict.items()
    }

    return train_loader, valid_loader


def generate_output_preprocess(tokenizer):
    if tokenizer._is_seq2seq:
        def preprocess(all_input_ids):
            return_value = []
            for input_ids in all_input_ids:
                # Accept tensors as well as lists; .index() below needs a list.
                input_ids = list(input_ids)
                # DataCollatorForSeq2Seq pads labels with -100 (the loss
                # ignore index), which is not a valid token id: truncate at
                # the first -100 before decoding.
                if -100 in input_ids:
                    input_ids = input_ids[:input_ids.index(-100)]
                return_value.append(
                    tokenizer.decode(input_ids, skip_special_tokens=True)
                )
            return return_value
        return preprocess
    else:
        return lambda x: x  # decoder-only outputs need no extra processing
```
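
As a minimal usage sketch: the `t5-small` checkpoint, the toy two-example dataset, and the place where the custom `_is_seq2seq` flag gets set are all illustrative assumptions, not part of the snippet above.

```python
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer._is_seq2seq = True  # assumption: the custom flag is set like this

def tokenize(batch):
    model_inputs = tokenizer(batch['src'], truncation=True)
    model_inputs['labels'] = tokenizer(
        text_target=batch['tgt'], truncation=True
    )['input_ids']
    return model_inputs

# Hypothetical toy data, tokenized into input_ids / attention_mask / labels.
raw = Dataset.from_dict({'src': ['translate: hello', 'translate: goodbye'],
                         'tgt': ['bonjour', 'au revoir']})
ds = raw.map(tokenize, batched=True, remove_columns=['src', 'tgt'])

train_loader, valid_loaders = generate_dataloader(
    tokenizer, ds, {'dev': ds}, train_bs=2, valid_bs=2
)
preprocess = generate_output_preprocess(tokenizer)

batch = next(iter(valid_loaders['dev']))
# Decode the -100-padded label ids back into reference strings.
print(preprocess(batch['labels'].tolist()))
```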