
glue_helper.py

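"""Helpers for loading GLUE / SuperGLUE datasets and building per-task
compute_metrics functions on top of the `evaluate` GLUE metrics."""
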
from datasets import load_dataset
from evaluate import load
import numpy as np

from _utils import prefix_dict_keys
from .my_label import MyClassLabel, MyRegresionLabel


class GLUEHelperBase:
    def __init__(self, base_name, load_names):
        self.base_name = base_name
        self.datasets = {}
        for name in load_names:
            self.__load_dataset(name)

    def __load_dataset(self, name):
        self.datasets[name] = load_dataset(self.base_name, name)

    @property
    def keys(self):
        return list(self.datasets.keys())

    def get_task_input(self, task_name):
        # Model input columns: everything except 'label' and 'idx'.
        return_value = list(self.datasets[task_name]['train'].column_names)
        return_value.remove('label')
        return_value.remove('idx')
        return return_value

    def get_task_train_key(self, task_name):
        return 'train'

    def get_task_validation_key(self, task_name):
        return 'validation',

    def get_dataset(self, task_name):
        if task_name not in self.datasets:
            self.__load_dataset(task_name)
        return self.datasets[task_name]

    def generate_compute_metrics(self, task_name, text2text: bool):
        task_output = self.get_task_output(task_name)
        glue_metric = load(self.base_name, task_name)

        def compute_metrics(y_pred, y_true):
            if text2text:
                y_pred = task_output.str2int(y_pred)
                y_true = task_output.str2int(y_true)
                if None in y_pred:
                    # An unparseable prediction maps to None; substitute a
                    # dummy mismatched pair so the metric returns its
                    # worst-case score instead of raising.
                    y_pred = [0, 1]
                    y_true = [1, 0]
            glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)
            glue_metrics['mean'] = np.mean(list(glue_metrics.values()))
            return glue_metrics

        return compute_metrics


class GLUEHelper(GLUEHelperBase):
    def __init__(self, load_names=None):
        if load_names is None:
            load_names = self.__class__.get_task_names()
        super().__init__('glue', load_names)

    @property
    def keys(self):
        return list(self.datasets.keys())

    @staticmethod
    def get_task_names():
        return [
            'cola', 'sst2', 'mrpc', 'qqp',
            'stsb',
            'mnli',  # has separate matched/mismatched validation splits
            'qnli', 'rte', 'wnli',
            # 'ax' has no train split
        ]

    @staticmethod
    def get_task_output(task_name):
        if task_name == 'stsb':
            return MyRegresionLabel()
        names = {
            'cola': ['unacceptable', 'acceptable'],
            'sst2': ['negative', 'positive'],
            'mrpc': ['not_equivalent', 'equivalent'],
            'qqp': ['not_duplicate', 'duplicate'],
            'mnli': ['entailment', 'neutral', 'contradiction'],
            'qnli': ['entailment', 'not_entailment'],
            'rte': ['entailment', 'not_entailment'],
            'wnli': ['not_entailment', 'entailment'],
        }[task_name]
        return MyClassLabel(names)

    def get_task_validation_key(self, task_name):
        if task_name == 'mnli':
            return 'validation_matched', 'validation_mismatched'
        return 'validation',


class SuperGLUEHelper(GLUEHelperBase):
    def __init__(self, load_names=None):
        if load_names is None:
            load_names = self.__class__.get_task_names()
        super().__init__('super_glue', load_names)

    def get_task_input(self, task_name):
        map_dict = {
            "wic": ("sentence1", "sentence2"),
            "wsc.fixed": ("span1_text", "span1_index", "span2_text", "span2_index", "text"),
            "multirc": ("question", "answer", "paragraph"),
            "copa": ('choice1', 'choice2', 'premise', 'question'),
            "boolq": ("question", "passage"),  # question first, so it is not truncated away
        }
        if task_name in map_dict:
            return map_dict[task_name]
        return super().get_task_input(task_name)

    @staticmethod
    def get_task_output(task_name):
        names = {
            'boolq': ['False', 'True'],
            'cb': ['entailment', 'contradiction', 'neutral'],
            'copa': ['choice1', 'choice2'],
            'multirc': ['False', 'True'],
            'rte': ['entailment', 'not_entailment'],
            'wic': ['False', 'True'],
            'wsc.fixed': ['False', 'True'],
        }[task_name]
        return MyClassLabel(names)

    @staticmethod
    def get_task_names():
        return [
            'boolq', 'cb', 'copa', 'multirc',
            # 'record' is a span-extraction task and is skipped here
            'rte', 'wic', 'wsc.fixed',
            # 'axb' and 'axg' have no train split
        ]

    def generate_compute_metrics(self, task_name, text2text: bool):
        if task_name in ['multirc', 'record']:
            # These metrics need each prediction paired with its 'idx' entry.
            task_output = self.get_task_output(task_name)
            glue_metric = load(self.base_name, task_name)
            all_idx = self.datasets[task_name]['validation']['idx']

            if task_name == 'multirc':
                def compute_metrics(y_pred, y_true):
                    y_pred = task_output.str2int(y_pred)
                    assert len(all_idx) == len(y_pred)
                    if None in y_pred:
                        # Unparseable prediction: report worst-case scores.
                        glue_metrics = {'exact_match': 0.0, 'f1_m': 0.0, 'f1_a': 0.0}
                    else:
                        y_pred = [
                            {
                                'prediction': y_pred_item,
                                'idx': idx,
                            } for (y_pred_item, idx) in zip(y_pred, all_idx)
                        ]
                        y_true = task_output.str2int(y_true)
                        glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)
                    glue_metrics['mean'] = np.mean([glue_metrics['exact_match'], glue_metrics['f1_a']])
                    return glue_metrics
            elif task_name == 'record':
                def compute_metrics(y_pred, y_true):
                    assert len(all_idx) == len(y_pred)
                    if None in y_pred:
                        # Unparseable prediction: report worst-case scores.
                        glue_metrics = {'exact_match': 0.0, 'f1': 0.0}
                    else:
                        y_pred = [
                            {
                                'prediction': y_pred_item,
                                'idx': idx,
                            } for (y_pred_item, idx) in zip(y_pred, all_idx)
                        ]
                        glue_metrics = glue_metric.compute(predictions=y_pred, references=y_true)
                    glue_metrics['mean'] = np.mean(list(glue_metrics.values()))
                    return glue_metrics
            return compute_metrics
        else:
            return super().generate_compute_metrics(task_name, text2text)
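

# --- Illustrative usage sketch (not part of the original helpers) ---
# Assumes `MyClassLabel.str2int` from `my_label` maps label strings to integer
# ids (returning None for unknown strings), mirroring datasets.ClassLabel, and
# that a text-to-text model emits label strings such as 'True' / 'False'.
if __name__ == '__main__':
    helper = SuperGLUEHelper(load_names=['boolq'])

    # Input columns and validation split(s) for the task.
    print(helper.get_task_input('boolq'))           # ('question', 'passage')
    print(helper.get_task_validation_key('boolq'))  # ('validation',)

    # Build a metric function for string predictions and score a toy batch.
    compute_metrics = helper.generate_compute_metrics('boolq', text2text=True)
    print(compute_metrics(y_pred=['True', 'False'], y_true=['True', 'True']))
    # e.g. {'accuracy': 0.5, 'mean': 0.5}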