dataloader.py

import torch
from transformers import DataCollatorForSeq2Seq, DataCollatorWithPadding


def generate_dataloader(tokenizer, ds_train, ds_valid_dict, train_bs, valid_bs):
    # `_is_seq2seq` is a project-specific flag set on the tokenizer elsewhere
    # (not a standard transformers attribute); it selects the collator that
    # also pads the label sequences.
    if tokenizer._is_seq2seq:
        col_fn = DataCollatorForSeq2Seq(
            tokenizer, return_tensors='pt', padding='longest'
        )
    else:
        col_fn = DataCollatorWithPadding(
            tokenizer, return_tensors='pt', padding='longest'
        )

    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=train_bs,
        collate_fn=col_fn,
        shuffle=True
    )
    # One loader per validation split; validation data is not shuffled.
    valid_loader = {
        key: torch.utils.data.DataLoader(
            val,
            batch_size=valid_bs,
            collate_fn=col_fn,
            # shuffle=True
        )
        for key, val in ds_valid_dict.items()
    }
    return train_loader, valid_loader


def generate_output_preprocess(tokenizer):
    if tokenizer._is_seq2seq:
        def preprocess(all_input_ids):
            # Decode each label sequence, truncating at the first -100
            # (the ignore index used to mask padding in seq2seq labels).
            # Expects lists of ints rather than tensors, since list.index() is used.
            return_value = []
            for input_ids in all_input_ids:
                if -100 in input_ids:
                    input_ids = input_ids[:input_ids.index(-100)]
                return_value.append(tokenizer.decode(input_ids, skip_special_tokens=True))
            return return_value
        return preprocess
    else:
        return lambda x: x  # identity function for non-seq2seq outputs
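
A minimal usage sketch (not part of dataloader.py), assuming pre-tokenized Hugging Face datasets and a tokenizer on which the project's `_is_seq2seq` flag has been set; `tokenized_train` and `tokenized_dev` are hypothetical names:

from transformers import AutoTokenizer
from dataloader import generate_dataloader, generate_output_preprocess

tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokenizer._is_seq2seq = True  # project-specific flag checked by generate_dataloader

train_loader, valid_loaders = generate_dataloader(
    tokenizer,
    ds_train=tokenized_train,              # hypothetical pre-tokenized datasets.Dataset
    ds_valid_dict={"dev": tokenized_dev},  # hypothetical validation split
    train_bs=16,
    valid_bs=32,
)

postprocess = generate_output_preprocess(tokenizer)
# e.g. postprocess(label_id_lists) -> decoded label strings with -100 padding stripped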