from datasets import load_dataset
from evaluate import load
import numpy as np
from _utils import prefix_dict_keys

from .my_label import MyClassLabel, MyRegresionLabel
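
# NOTE (assumption): MyClassLabel and MyRegresionLabel are not shown in this
# file; they are assumed to expose a datasets.ClassLabel-like interface, roughly:
#   MyClassLabel(names).str2int(labels) -> list of class ids, with None for any
#                                          string that matches no label name
#   MyRegresionLabel().str2int(values)  -> list of floats parsed from generated text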


class GLUEHelperBase:
    def __init__(self, base_name, load_names):
        self.base_name = base_name
        self.datasets = {}

        for name in load_names:
            self.__load_dataset(name)

    def __load_dataset(self, name):
        self.datasets[name] = load_dataset(self.base_name, name)

    @property
    def keys(self):
        return list(self.datasets.keys())

    def get_task_input(self, task_name):
        # Every column except 'label' and 'idx' is a model input.
        return_value = list(self.datasets[task_name]['train'].column_names)
        return_value.remove('label')
        return_value.remove('idx')
        return return_value

    def get_task_train_key(self, task_name):
        return 'train'

    def get_task_validation_key(self, task_name):
        # Returned as a one-element tuple; GLUEHelper overrides this for MNLI,
        # which has two validation splits.
        return ('validation',)

    def get_dataset(self, task_name):
        if task_name not in self.datasets:
            self.__load_dataset(task_name)
        return self.datasets[task_name]

    def generate_compute_metrics(self, task_name, text2text: bool):
        # get_task_output is provided by the subclasses.
        task_output = self.get_task_output(task_name)
        glue_metric = load(self.base_name, task_name)

        def compute_metrics(y_pred, y_true):
            if text2text:
                # Map generated label strings back to class ids.
                y_pred = task_output.str2int(y_pred)
                y_true = task_output.str2int(y_true)
                if None in y_pred:
                    # At least one prediction could not be parsed as a label:
                    # substitute a deliberately wrong pair so the metric is 0.
                    y_pred = [0, 1]
                    y_true = [1, 0]

            glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)
            glue_metrics['mean'] = np.mean(list(glue_metrics.values()))

            return glue_metrics

        return compute_metrics
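
# Hedged usage sketch: for SST-2 with a text2text model, the closure returned
# above behaves roughly like
#   compute_metrics = helper.generate_compute_metrics('sst2', text2text=True)
#   compute_metrics(['positive', 'negative'], ['positive', 'positive'])
#   -> {'accuracy': 0.5, 'mean': 0.5}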


class GLUEHelper(GLUEHelperBase):
    def __init__(self, load_names=None):
        if load_names is None:
            load_names = self.__class__.get_task_names()

        super().__init__('glue', load_names)

    @staticmethod
    def get_task_names():
        return [
            'cola', 'sst2', 'mrpc', 'qqp',
            'stsb',
            'mnli',  # has separate matched/mismatched validation splits
            'qnli', 'rte', 'wnli',
            # 'ax' has no train split
        ]

    @staticmethod
    def get_task_output(task_name):
        if task_name == 'stsb':
            # STS-B is a regression task (similarity scores), not classification.
            return MyRegresionLabel()
        names = {
            'cola': ['unacceptable', 'acceptable'],
            'sst2': ['negative', 'positive'],
            'mrpc': ['not_equivalent', 'equivalent'],
            'qqp': ['not_duplicate', 'duplicate'],
            'mnli': ['entailment', 'neutral', 'contradiction'],
            'qnli': ['entailment', 'not_entailment'],
            'rte': ['entailment', 'not_entailment'],
            'wnli': ['not_entailment', 'entailment'],
        }[task_name]
        return MyClassLabel(names)

    def get_task_validation_key(self, task_name):
        if task_name == 'mnli':
            return 'validation_matched', 'validation_mismatched'
        return ('validation',)


class SuperGLUEHelper(GLUEHelperBase):
    def __init__(self, load_names=None):
        if load_names is None:
            load_names = self.__class__.get_task_names()

        super().__init__('super_glue', load_names)

    def get_task_input(self, task_name):
        map_dict = {
            "wic": ("sentence1", "sentence2"),
            "wsc.fixed": ("span1_text", "span1_index", "span2_text", "span2_index", "text"),
            "multirc": ("question", "answer", "paragraph"),
            "copa": ('choice1', 'choice2', 'premise', 'question'),
            "boolq": ("question", "passage"),  # question first so truncation drops the passage, not the question
        }
        if task_name in map_dict:
            return map_dict[task_name]
        return super().get_task_input(task_name)

    @staticmethod
    def get_task_output(task_name):
        names = {
            'boolq': ['False', 'True'],
            'cb': ['entailment', 'contradiction', 'neutral'],
            'copa': ['choice1', 'choice2'],
            'multirc': ['False', 'True'],
            'rte': ['entailment', 'not_entailment'],
            'wic': ['False', 'True'],
            'wsc.fixed': ['False', 'True'],
        }[task_name]
        return MyClassLabel(names)

    @staticmethod
    def get_task_names():
        return [
            'boolq', 'cb', 'copa', 'multirc',
            # 'record' requires span-based answers, which this pipeline does not handle
            'rte', 'wic', 'wsc.fixed',
            # 'axb' and 'axg' have no train split
        ]

    def generate_compute_metrics(self, task_name, text2text: bool):
        # 'multirc' and 'record' metrics expect per-example dicts keyed by
        # 'idx', so they need custom closures instead of the base implementation.
        if task_name in ['multirc', 'record']:
            task_output = self.get_task_output(task_name)
            glue_metric = load(self.base_name, task_name)
            all_idx = self.datasets[task_name]['validation']['idx']

            if task_name == 'multirc':
                def compute_metrics(y_pred, y_true):
                    y_pred = task_output.str2int(y_pred)
                    assert len(all_idx) == len(y_pred)
                    if None in y_pred:
                        # Unparseable predictions: report the worst possible scores.
                        glue_metrics = {'exact_match': 0.0, 'f1_m': 0.0, 'f1_a': 0.0}
                    else:
                        y_pred = [
                            {
                                'prediction': y_pred_item,
                                'idx': idx,
                            } for (y_pred_item, idx) in zip(y_pred, all_idx)
                        ]
                        y_true = task_output.str2int(y_true)
                        glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)

                    glue_metrics['mean'] = np.mean([glue_metrics['exact_match'], glue_metrics['f1_a']])

                    return glue_metrics

            elif task_name == 'record':
                def compute_metrics(y_pred, y_true):
                    # ReCoRD predictions are free-text answers, so no str2int here.
                    assert len(all_idx) == len(y_pred)
                    if None in y_pred:
                        glue_metrics = {'exact_match': 0.0, 'f1': 0.0}
                    else:
                        y_pred = [
                            {
                                'prediction': y_pred_item,
                                'idx': idx,
                            } for (y_pred_item, idx) in zip(y_pred, all_idx)
                        ]
                        glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)

                    glue_metrics['mean'] = np.mean(list(glue_metrics.values()))

                    return glue_metrics

            return compute_metrics

        else:
            return super().generate_compute_metrics(task_name, text2text)
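

# Hedged end-to-end sketch (not part of the original pipeline): how the split
# keys, input columns, and metric closure are typically consumed.
if __name__ == '__main__':
    helper = SuperGLUEHelper(load_names=['boolq'])
    print(helper.get_task_input('boolq'))           # ('question', 'passage')
    print(helper.get_task_validation_key('boolq'))  # ('validation',)
    compute_metrics = helper.generate_compute_metrics('boolq', text2text=True)
    # With a text2text model, predictions and references are label strings:
    print(compute_metrics(['True', 'False'], ['True', 'True']))  # accuracy 0.5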