from datasets import load_dataset
from evaluate import load
import numpy as np

from _utils import prefix_dict_keys
from .my_label import MyClassLabel, MyRegresionLabel


class GLUEHelperBase:
    """Shared dataset-loading and metric logic for the GLUE and SuperGLUE helpers."""

    def __init__(self, base_name, load_names):
        self.base_name = base_name
        self.datasets = {}
        for name in load_names:
            self.__load_dataset(name)

    def __load_dataset(self, name):
        self.datasets[name] = load_dataset(self.base_name, name)

    @property
    def keys(self):
        return list(self.datasets.keys())

    def get_task_input(self, task_name):
        # Input columns are everything except the label and the example index.
        return_value = list(self.datasets[task_name]['train'].column_names)
        return_value.remove('label')
        return_value.remove('idx')
        return return_value

    def get_task_train_key(self, task_name):
        return 'train'

    def get_task_validation_key(self, task_name):
        # Returned as a tuple so tasks with several validation splits (e.g. MNLI)
        # can be handled uniformly.
        return 'validation',

    def get_dataset(self, task_name):
        if task_name not in self.datasets:
            self.__load_dataset(task_name)
        return self.datasets[task_name]

    def generate_compute_metrics(self, task_name, text2text: bool):
        task_output = self.get_task_output(task_name)
        glue_metric = load(self.base_name, task_name)

        def compute_metrics(y_pred, y_true):
            if text2text:
                # Map generated label strings back to class ids.
                y_pred = task_output.str2int(y_pred)
                y_true = task_output.str2int(y_true)
                if None in y_pred:
                    # At least one prediction could not be parsed: fall back to a
                    # degenerate pair that yields zero scores.
                    y_pred = [0, 1]
                    y_true = [1, 0]

            glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)
            glue_metrics['mean'] = np.mean(list(glue_metrics.values()))
            return glue_metrics

        return compute_metrics


class GLUEHelper(GLUEHelperBase):
    def __init__(self, load_names=None):
        if load_names is None:
            load_names = self.__class__.get_task_names()
        super().__init__('glue', load_names)

    @property
    def keys(self):
        return list(self.datasets.keys())

    @staticmethod
    def get_task_names():
        return [
            'cola', 'sst2', 'mrpc', 'qqp', 'stsb',
            'mnli',  # has separate matched/mismatched validation splits
            'qnli', 'rte', 'wnli',
            # 'ax' has no train split
        ]

    @staticmethod
    def get_task_output(task_name):
        if task_name == 'stsb':
            return MyRegresionLabel()

        names = {
            'cola': ['unacceptable', 'acceptable'],
            'sst2': ['negative', 'positive'],
            'mrpc': ['not_equivalent', 'equivalent'],
            'qqp': ['not_duplicate', 'duplicate'],
            'mnli': ['entailment', 'neutral', 'contradiction'],
            'qnli': ['entailment', 'not_entailment'],
            'rte': ['entailment', 'not_entailment'],
            'wnli': ['not_entailment', 'entailment'],
        }[task_name]
        return MyClassLabel(names)

    def get_task_validation_key(self, task_name):
        if task_name == 'mnli':
            return 'validation_matched', 'validation_mismatched'
        return 'validation',


class SuperGLUEHelper(GLUEHelperBase):
    def __init__(self, load_names=None):
        if load_names is None:
            load_names = self.__class__.get_task_names()
        super().__init__('super_glue', load_names)

    def get_task_input(self, task_name):
        # Explicit column ordering for tasks whose inputs need a fixed layout.
        map_dict = {
            "wic": ("sentence1", "sentence2"),
            "wsc.fixed": ("span1_text", "span1_index", "span2_text", "span2_index", "text"),
            "multirc": ("question", "answer", "paragraph"),
            "copa": ('choice1', 'choice2', 'premise', 'question'),
            "boolq": ("question", "passage"),  # keep the question from being truncated
        }
        if task_name in map_dict:
            return map_dict[task_name]
        return super().get_task_input(task_name)

    @staticmethod
    def get_task_output(task_name):
        names = {
            'boolq': ['False', 'True'],
            'cb': ['entailment', 'contradiction', 'neutral'],
            'copa': ['choice1', 'choice2'],
            'multirc': ['False', 'True'],
            'rte': ['entailment', 'not_entailment'],
            'wic': ['False', 'True'],
            'wsc.fixed': ['False', 'True'],
        }[task_name]
        return MyClassLabel(names)

    @staticmethod
    def get_task_names():
        return [
            'boolq', 'cb', 'copa',
            'multirc',
            # 'record' is excluded (span-extraction answers are not handled)
            'rte', 'wic', 'wsc.fixed',
            # 'axb' and 'axg' have no train split
        ]

    def generate_compute_metrics(self, task_name, text2text: bool):
        if task_name in ['multirc', 'record']:
            # These tasks require per-example indices in the prediction format
            # expected by the SuperGLUE metric.
            task_output = self.get_task_output(task_name)
            glue_metric = load(self.base_name, task_name)
            all_idx = self.datasets[task_name]['validation']['idx']

            if task_name == 'multirc':
                def compute_metrics(y_pred, y_true):
                    y_pred = task_output.str2int(y_pred)
                    assert len(all_idx) == len(y_pred)
                    if None in y_pred:
                        # Unparseable predictions: report zero scores.
                        glue_metrics = {'exact_match': 0.0, 'f1_m': 0.0, 'f1_a': 0.0}
                    else:
                        y_pred = [
                            {'prediction': y_pred_item, 'idx': idx}
                            for (y_pred_item, idx) in zip(y_pred, all_idx)
                        ]
                        y_true = task_output.str2int(y_true)
                        glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)
                    glue_metrics['mean'] = np.mean([glue_metrics['exact_match'], glue_metrics['f1_a']])
                    return glue_metrics
            elif task_name == 'record':
                def compute_metrics(y_pred, y_true):
                    assert len(all_idx) == len(y_pred)
                    if None in y_pred:
                        glue_metrics = {'exact_match': 0.0, 'f1': 0.0}
                    else:
                        y_pred = [
                            {'prediction': y_pred_item, 'idx': idx}
                            for (y_pred_item, idx) in zip(y_pred, all_idx)
                        ]
                        glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)
                    glue_metrics['mean'] = np.mean(list(glue_metrics.values()))
                    return glue_metrics

            return compute_metrics
        else:
            return super().generate_compute_metrics(task_name, text2text)
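

# Minimal usage sketch, not part of the original module: it shows how the helper
# API above fits together. Assumptions: 'rte' is used purely as an illustrative
# task, running this downloads GLUE data via the 'datasets' package, and the
# module is invoked with `python -m ...` so the relative import resolves.
if __name__ == '__main__':
    helper = GLUEHelper(load_names=['rte'])
    print(helper.keys)                            # ['rte']
    print(helper.get_task_input('rte'))           # ['sentence1', 'sentence2']
    print(helper.get_task_validation_key('rte'))  # ('validation',)

    # Build the metric function for integer-label predictions (text2text=False)
    # and score a dummy prediction pair; GLUE RTE reports accuracy.
    compute_metrics = helper.generate_compute_metrics('rte', text2text=False)
    print(compute_metrics(y_pred=[0, 1], y_true=[0, 1]))  # e.g. {'accuracy': 1.0, 'mean': 1.0}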