from datasets import DatasetDict

from .glue_helper import GLUEHelper, SuperGLUEHelper


class AutoLoad:
    def __init__(self, tokenizer, n_prefix_token=0, lazy_load=True):
        self.tokenizer = tokenizer
        self.n_prefix_token = n_prefix_token

        # Pad value to prepend per encoded field when reserving prefix slots.
        self.post_tokenizer_map = {
            'input_ids': 0,
            'attention_mask': 1,
            'token_type_ids': 0
        }

        load_names = [] if lazy_load else None
        self.glue_helper = GLUEHelper(load_names)
        self.superglue_helper = SuperGLUEHelper(load_names)

    @property
    def _is_bert(self):
        # Heuristic: treat any checkpoint with 'bert' in its name as an
        # encoder-only model.
        return 'bert' in self.tokenizer.name_or_path.lower()

    def __output_type(self):
        return_value = ['input_ids', 'attention_mask', 'labels']

        if self._is_bert:
            return return_value + ['token_type_ids']

        return return_value

    def _add_prefix(self, tokenizer_out):
        # Prepend `n_prefix_token` placeholder positions to every encoded field.
        if self.n_prefix_token == 0:
            return tokenizer_out
        for special_key, pad_val in self.post_tokenizer_map.items():
            if special_key in tokenizer_out:
                for batch_item in tokenizer_out[special_key]:
                    # In-place prepend via list slice assignment.
                    batch_item[:0] = [pad_val] * self.n_prefix_token
        return tokenizer_out
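
    # Example (hypothetical token ids) of what `_add_prefix` does with
    # n_prefix_token=2:
    #   input_ids      [101, 7592, 102] -> [0, 0, 101, 7592, 102]
    #   attention_mask [1, 1, 1]        -> [1, 1, 1, 1, 1]
    # attention_mask stays 1 on the prefix positions, presumably so a prompt
    # module can later fill those reserved slots with trainable embeddings.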

    def map_dataset(self, dataset, input_info, output_info, task_name):
        def preprocess(input_dict_row):
            return_value = {}
            if task_name == 'wic':
                # WiC: mark the target word in both sentences so the model
                # sees which occurrence is being compared.
                def annotate_word(sent, span):
                    return sent[:span.start] + "** " + sent[span] + " **" + sent[span.stop:]

                slice1 = slice(input_dict_row['start1'], input_dict_row['end1'])
                slice2 = slice(input_dict_row['start2'], input_dict_row['end2'])
                input_dict_row['sentence1'] = annotate_word(input_dict_row['sentence1'], slice1)
                input_dict_row['sentence2'] = annotate_word(input_dict_row['sentence2'], slice2)

                return_value['sentence1'] = input_dict_row['sentence1']
                return_value['sentence2'] = input_dict_row['sentence2']

            if len(input_info) == 1:
                return_value['merged'] = input_dict_row[input_info[0]]
            else:
                return_value['merged'] = " ".join(
                    f"{key}: {input_dict_row[key]}" for key in input_info
                )
            return return_value
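
        # Example (hypothetical row): with input_info = ('premise', 'hypothesis'),
        # 'merged' becomes "premise: <premise text> hypothesis: <hypothesis text>".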

        def create_input(input_dict_rows):
            if self._is_bert:
                if len(input_info) < 3:
                    # One or two fields go straight to the tokenizer as a
                    # sentence (pair).
                    generator = (input_dict_rows[input_name] for input_name in input_info)
                else:
                    # Three or more fields were already merged in `preprocess`.
                    generator = [input_dict_rows['merged']]
                tokenizer_out = self.tokenizer(
                    *generator,
                    truncation=True,
                    # Reserve room for the prefix tokens added in `_add_prefix`.
                    max_length=self.tokenizer.model_max_length - self.n_prefix_token
                )
            else:  # seq2seq tokenizers (e.g. T5, BART) take the merged text
                tokenizer_out = self.tokenizer(input_dict_rows['merged'])

            return self._add_prefix(tokenizer_out)

        def create_output(input_dict):
            if self.tokenizer._is_seq2seq:
                # Text-to-text models generate the label string, so the target
                # is the token ids of the verbalized label.
                tokens = self.tokenizer(output_info.int2str(input_dict['label']))
                return tokens.input_ids
            return input_dict['label']
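
        # Example (hypothetical ClassLabel): if output_info.int2str(1) is
        # 'entailment', a seq2seq model is trained to generate "entailment",
        # while an encoder-only model keeps the integer class id 1.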

        def map_function(input_dict):
            return {
                **create_input(input_dict),
                'labels': create_output(input_dict)
            }

        dataset = dataset.map(preprocess)  # row by row
        dataset = dataset.map(map_function, batched=True)  # tokenize in batches
        dataset.set_format(type='torch', columns=self.__output_type())
        return dataset

    def get_glue(self, category, task_name):
        glue_agent = {
            'glue': self.glue_helper,
            'superglue': self.superglue_helper
        }[category]

        dataset = glue_agent.get_dataset(task_name)

        train_ds = dataset[glue_agent.get_task_train_key(task_name)]
        valid_ds_keys = glue_agent.get_task_validation_key(task_name)
        valid_ds_dict = DatasetDict({
            key: dataset[key]
            for key in valid_ds_keys
        })

        kwargs = {
            'input_info': glue_agent.get_task_input(task_name),
            'output_info': glue_agent.get_task_output(task_name),
            'task_name': task_name
        }

        return {
            'name': f'{category}-{task_name}',
            'train': self.map_dataset(train_ds, **kwargs),
            'valid_dict': self.map_dataset(valid_ds_dict, **kwargs),
            'compute_metrics': glue_agent.generate_compute_metrics(
                task_name, text2text=self.tokenizer._is_seq2seq
            )
        }
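
    # The returned bundle matches a standard `transformers` training setup,
    # e.g. (hypothetical wiring; `model` and `args` defined elsewhere):
    #   bundle = loader.get_glue('glue', 'mrpc')
    #   Trainer(model=model, args=args,
    #           train_dataset=bundle['train'],
    #           eval_dataset=bundle['valid_dict'],
    #           compute_metrics=bundle['compute_metrics'])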

    def get_and_map(self, task_name):
        category, ds_name = task_name.split(':')
        if category in ['glue', 'superglue']:
            return self.get_glue(category, ds_name)

        raise NotImplementedError(f"unknown task category: {category}")

    @staticmethod
    def get_task_output(full_task_name):
        category, task_name = full_task_name.split(':')
        if category in ['glue', 'superglue']:
            selected_helper = {
                'glue': GLUEHelper,
                'superglue': SuperGLUEHelper
            }[category]
            return selected_helper.get_task_output(task_name)
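
# Usage sketch (assumptions: this module is imported as part of its package so
# the relative import resolves; the checkpoint and task names are examples, and
# `_is_seq2seq` is a flag this module expects the caller to set on the tokenizer):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   tokenizer._is_seq2seq = False  # encoder-only: labels stay integer class ids
#   loader = AutoLoad(tokenizer, n_prefix_token=20)
#   bundle = loader.get_and_map('superglue:boolq')
#   train_dataset = bundle['train']  # torch-formatted, ready for a DataLoader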