import torch
from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding


def generate_dataloader(tokenizer, ds_train, ds_valid_dict, train_bs, valid_bs):
    """Build one training DataLoader and one validation DataLoader per split."""
    # `_is_seq2seq` is a custom flag expected to be set on the tokenizer by the
    # caller; it selects the label-aware collator for encoder-decoder models.
    if tokenizer._is_seq2seq:
        col_fn = DataCollatorForSeq2Seq(
            tokenizer, return_tensors='pt', padding='longest'
        )
    else:
        col_fn = DataCollatorWithPadding(
            tokenizer, return_tensors='pt', padding='longest'
        )

    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=train_bs,
        collate_fn=col_fn,
        shuffle=True
    )
    # Validation loaders are left unshuffled so evaluation order stays deterministic.
    valid_loader = {
        key: torch.utils.data.DataLoader(
            val,
            batch_size=valid_bs,
            collate_fn=col_fn,
            # shuffle=True
        )
        for key, val in ds_valid_dict.items()
    }
    return train_loader, valid_loader


def generate_output_preprocess(tokenizer):
    """Return a function that turns sequences of label ids into decoded strings."""
    if tokenizer._is_seq2seq:
        def preprocess(all_input_ids):
            # Each element is a Python list of ids; -100 marks padded label
            # positions, so decode only the ids before the first -100.
            return_value = []
            for input_ids in all_input_ids:
                if -100 in input_ids:
                    input_ids = input_ids[:input_ids.index(-100)]
                return_value.append(
                    tokenizer.decode(input_ids, skip_special_tokens=True)
                )
            return return_value
        return preprocess
    else:
        return lambda x: x  # identity function
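

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical toy data. It assumes the custom
    # `_is_seq2seq` flag is set by the caller (this module only reads it) and
    # uses "t5-small" purely as an illustrative checkpoint.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    tokenizer._is_seq2seq = True  # custom flag consumed by the helpers above

    # Tiny pre-tokenized datasets: each example carries input_ids and labels.
    pairs = [
        ("translate English to German: hello", "hallo"),
        ("translate English to German: thank you", "danke"),
    ]
    ds_train = [tokenizer(src, text_target=tgt) for src, tgt in pairs]
    ds_valid = {"dev": ds_train}

    train_loader, valid_loaders = generate_dataloader(
        tokenizer, ds_train, ds_valid, train_bs=2, valid_bs=2
    )
    postprocess = generate_output_preprocess(tokenizer)

    batch = next(iter(train_loader))
    # Decode the padded label ids back to text, truncating at the first -100.
    print(postprocess(batch["labels"].tolist()))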