
autoload.py

from datasets import DatasetDict

from .glue_helper import GLUEHelper, SuperGLUEHelper


class AutoLoad:
    def __init__(self, tokenizer, n_prefix_token=0, lazy_load=True):
        self.tokenizer = tokenizer
        self.n_prefix_token = n_prefix_token
        # self.lowercase = lowercase
        # pad values used when prepending prefix tokens to each tokenized field
        self.post_tokenizer_map = {
            'input_ids': 0,
            'attention_mask': 1,
            'token_type_ids': 0
        }
        load_names = [] if lazy_load else None
        self.glue_helper = GLUEHelper(load_names)
        self.superglue_helper = SuperGLUEHelper(load_names)
    @property
    def _is_bert(self):
        return 'bert' in self.tokenizer.name_or_path.lower()

    def __output_type(self):
        return_value = ['input_ids', 'attention_mask', 'labels']
        if self._is_bert:
            return return_value + ['token_type_ids']
        return return_value
    def _add_prefix(self, tokenizer_out):
        if self.n_prefix_token == 0:
            return tokenizer_out
        # prepend n_prefix_token placeholder values to every tokenized field
        for special_key, pad_val in self.post_tokenizer_map.items():
            if special_key in tokenizer_out:
                for batch_item in tokenizer_out[special_key]:
                    batch_item[:0] = [pad_val] * self.n_prefix_token
        return tokenizer_out
    def map_dataset(self, dataset, input_info, output_info, task_name):
        def preprocess(input_dict_row):
            return_value = {}
            if task_name == 'wic':
                # mark the WiC target word in both sentences with ** ... **
                sent1 = input_dict_row['sentence1']
                sent2 = input_dict_row['sentence2']
                slice1 = slice(input_dict_row['start1'], input_dict_row['end1'])
                slice2 = slice(input_dict_row['start2'], input_dict_row['end2'])
                annotate_word = lambda _sent, _slice: (
                    _sent[:_slice.start] + "** " + _sent[_slice] + " **" + _sent[_slice.stop:]
                )
                input_dict_row['sentence1'] = annotate_word(sent1, slice1)
                input_dict_row['sentence2'] = annotate_word(sent2, slice2)
                return_value['sentence1'] = input_dict_row['sentence1']
                return_value['sentence2'] = input_dict_row['sentence2']
            if len(input_info) == 1:
                return_value['merged'] = input_dict_row[input_info[0]]
            else:
                return_value['merged'] = "".join(f"{key}: {input_dict_row[key]} " for key in input_info)
            return return_value
        def create_input(input_dict_rows):
            if self._is_bert:
                if len(input_info) < 3:
                    # one or two columns: pass them as separate tokenizer segments
                    generator = (input_dict_rows[input_name] for input_name in input_info)
                else:
                    # three or more columns were already merged in preprocess
                    generator = [input_dict_rows['merged']]
                tokenizer_out = self.tokenizer(
                    *generator,
                    truncation=True,
                    max_length=self.tokenizer.model_max_length - self.n_prefix_token
                )
            else:  # t5 or bart: tokenize the merged text
                tokenizer_out = self.tokenizer(input_dict_rows['merged'])
            return self._add_prefix(tokenizer_out)
        def create_output(input_dict):
            if self.tokenizer._is_seq2seq:
                # text-to-text models predict the label name as token ids
                tokens = self.tokenizer(output_info.int2str(input_dict['label']))
                return tokens.input_ids
            return input_dict['label']

        def map_function(input_dict):
            return {
                **create_input(input_dict),
                'labels': create_output(input_dict)
            }
        dataset = dataset.map(preprocess)  # row-by-row preprocessing
        dataset = dataset.map(map_function, batched=True)  # batched tokenization
        dataset.set_format(type='torch', columns=self.__output_type())
        return dataset
    def get_glue(self, category, task_name):
        glue_agent = {
            'glue': self.glue_helper,
            'superglue': self.superglue_helper
        }[category]
        dataset = glue_agent.get_dataset(task_name)
        train_ds = dataset[glue_agent.get_task_train_key(task_name)]
        valid_ds_keys = glue_agent.get_task_validation_key(task_name)
        valid_ds_dict = DatasetDict({
            key: dataset[key]
            for key in valid_ds_keys
        })
        kwargs = {
            'input_info': glue_agent.get_task_input(task_name),
            'output_info': glue_agent.get_task_output(task_name),
            'task_name': task_name
        }
        return {
            'name': f'{category}-{task_name}',
            'train': self.map_dataset(train_ds, **kwargs),
            'valid_dict': self.map_dataset(valid_ds_dict, **kwargs),
            'compute_metrics': glue_agent.generate_compute_metrics(
                task_name, text2text=self.tokenizer._is_seq2seq
            )
        }
    def get_and_map(self, task_name):
        category, ds_name = task_name.split(':')
        if category in ['glue', 'superglue']:
            return self.get_glue(category, ds_name)
        raise NotImplementedError(f"unknown task category: {category}")
    @staticmethod
    def get_task_output(full_task_name):
        category, task_name = full_task_name.split(':')
        if category in ['glue', 'superglue']:
            selected_helper = {
                'glue': GLUEHelper,
                'superglue': SuperGLUEHelper
            }[category]
            return selected_helper.get_task_output(task_name)
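
Below is a minimal usage sketch, not part of the file above: it assumes a Hugging Face tokenizer and that the caller sets the `_is_seq2seq` attribute, which AutoLoad reads but never defines itself; the task name `superglue:boolq` and `n_prefix_token=20` are only illustrations.

# Hypothetical usage sketch: wire AutoLoad to a BERT tokenizer and load one
# SuperGLUE task. `_is_seq2seq` is assumed to be attached by the caller,
# since AutoLoad only reads it from the tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
tokenizer._is_seq2seq = False  # classification labels, not text-to-text

loader = AutoLoad(tokenizer, n_prefix_token=20)
task = loader.get_and_map('superglue:boolq')

train_ds = task['train']                    # torch-formatted Dataset
valid_dict = task['valid_dict']             # DatasetDict of validation splits
compute_metrics = task['compute_metrics']   # metric callable for a Trainer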