from datasets import DatasetDict

from .glue_helper import GLUEHelper, SuperGLUEHelper


class AutoLoad:
    def __init__(self, tokenizer, n_prefix_token=0, lazy_load=True):
        self.tokenizer = tokenizer
        self.n_prefix_token = n_prefix_token
        # self.lowercase = lowercase
        # Fill values used when prepending prefix positions to each encoded
        # field: ids default to 0, but attention_mask is 1 so the model
        # attends to the prefix tokens.
        self.post_tokenizer_map = {
            'input_ids': 0,
            'attention_mask': 1,
            'token_type_ids': 0
        }
        load_names = [] if lazy_load else None  # empty list defers dataset loading
        self.glue_helper = GLUEHelper(load_names)
        self.superglue_helper = SuperGLUEHelper(load_names)

    @property
    def _is_bert(self):
        return 'bert' in self.tokenizer.name_or_path.lower()

    def __output_type(self):
        return_value = ['input_ids', 'attention_mask', 'labels']
        if self._is_bert:
            return return_value + ['token_type_ids']
        return return_value

    def _add_prefix(self, tokenizer_out):
        # Prepend n_prefix_token placeholder positions to every encoded field.
        if self.n_prefix_token == 0:
            return tokenizer_out
        for special_key, pad_val in self.post_tokenizer_map.items():
            if special_key in tokenizer_out:
                for batch_item in tokenizer_out[special_key]:
                    batch_item[:0] = [pad_val] * self.n_prefix_token
        return tokenizer_out

    def map_dataset(self, dataset, input_info, output_info, task_name):
        def preprocess(input_dict_row):
            return_value = {}
            if task_name == 'wic':
                # WiC: mark the target word span in both sentences with ** ... **.
                sent1 = input_dict_row['sentence1']
                sent2 = input_dict_row['sentence2']
                slice1 = slice(input_dict_row['start1'], input_dict_row['end1'])
                slice2 = slice(input_dict_row['start2'], input_dict_row['end2'])

                def annotate_word(_sent, _slice):
                    return _sent[:_slice.start] + "** " + _sent[_slice] + " **" + _sent[_slice.stop:]

                input_dict_row['sentence1'] = annotate_word(sent1, slice1)
                input_dict_row['sentence2'] = annotate_word(sent2, slice2)
                return_value['sentence1'] = input_dict_row['sentence1']
                return_value['sentence2'] = input_dict_row['sentence2']
            if len(input_info) == 1:
                return_value['merged'] = input_dict_row[input_info[0]]
            else:
                # e.g. "premise: ... hypothesis: ... "
                return_value['merged'] = "".join(f"{key}: {input_dict_row[key]} " for key in input_info)
            return return_value

        def create_input(input_dict_rows):
            if self._is_bert:
                if len(input_info) < 3:
                    # one or two fields: use the tokenizer's sentence /
                    # sentence-pair encoding directly
                    generator = (input_dict_rows[input_name] for input_name in input_info)
                else:
                    generator = [input_dict_rows['merged']]
                tokenizer_out = self.tokenizer(
                    *generator,
                    truncation=True,
                    # reserve room for the prefix tokens added in _add_prefix
                    max_length=self.tokenizer.model_max_length - self.n_prefix_token
                )
            else:  # seq2seq models (t5/bart): always encode the merged string
                tokenizer_out = self.tokenizer(input_dict_rows['merged'])
            return self._add_prefix(tokenizer_out)

        def create_output(input_dict):
            # _is_seq2seq is a custom flag expected on the tokenizer,
            # set by the caller (see usage sketch below).
            if self.tokenizer._is_seq2seq:
                # text-to-text: the label is the tokenized class name
                tokens = self.tokenizer(output_info.int2str(input_dict['label']))
                return tokens.input_ids
            return input_dict['label']

        def map_function(input_dict):
            return {
                **create_input(input_dict),
                'labels': create_output(input_dict)
            }

        dataset = dataset.map(preprocess)
        # batched=True: map_function receives each column as a list of rows
        dataset = dataset.map(map_function, batched=True)
        dataset.set_format(type='torch', columns=self.__output_type())
        return dataset

    def get_glue(self, category, task_name):
        glue_agent = {
            'glue': self.glue_helper,
            'superglue': self.superglue_helper
        }[category]
        dataset = glue_agent.get_dataset(task_name)
        train_ds = dataset[glue_agent.get_task_train_key(task_name)]
        valid_ds_keys = glue_agent.get_task_validation_key(task_name)
        valid_ds_dict = DatasetDict({key: dataset[key] for key in valid_ds_keys})
        kwargs = {
            'input_info': glue_agent.get_task_input(task_name),
            'output_info': glue_agent.get_task_output(task_name),
            'task_name': task_name
        }
        return {
            'name': f'{category}-{task_name}',
            'train': self.map_dataset(train_ds, **kwargs),
            'valid_dict': self.map_dataset(valid_ds_dict, **kwargs),
            'compute_metrics': glue_agent.generate_compute_metrics(
                task_name, text2text=self.tokenizer._is_seq2seq
            )
        }

    def get_and_map(self, task_name):
        # task_name has the form '<category>:<dataset>', e.g. 'superglue:boolq'
        category, ds_name = task_name.split(':')
        if category in ['glue', 'superglue']:
            return self.get_glue(category, ds_name)
        raise NotImplementedError(f"unknown task category: {category}")

    @staticmethod
    def get_task_output(full_task_name):
        category, task_name = full_task_name.split(':')
        if category in ['glue', 'superglue']:
            selected_helper = {
                'glue': GLUEHelper,
                'superglue': SuperGLUEHelper
            }[category]
            return selected_helper.get_task_output(task_name)
        raise NotImplementedError(f"unknown task category: {category}")
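
# ------------------------------------------------------------------
# Usage sketch. Assumptions: a HuggingFace tokenizer with the custom
# `_is_seq2seq` flag set by the caller (the methods above read it but
# never define it), and an illustrative checkpoint name — neither is
# prescribed by this module.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   tokenizer._is_seq2seq = False   # classification labels, not text-to-text
#   loader = AutoLoad(tokenizer, n_prefix_token=20)
#   bundle = loader.get_and_map('superglue:boolq')
#   train_ds = bundle['train']              # torch-formatted Dataset
#   valid_dict = bundle['valid_dict']       # DatasetDict of validation splits
#   metrics_fn = bundle['compute_metrics']
# ------------------------------------------------------------------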