adapterhub.py
from os import PathLike
from pathlib import Path
from typing import Dict, Union, Optional, Iterable

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from sklearn.metrics import classification_report
from transformers import TrainingArguments, BertAdapterModel, EvalPrediction, AdapterTrainer
from transformers.adapters import Fuse
class BertAdapterModelWrapper:
    """Wrapper around BertAdapterModel for adding, fusing, fine-tuning and evaluating adapters."""

    def __init__(self, base_model_name: Union[str, PathLike[str]], mask_token_id: int = -100):
        self.model = BertAdapterModel.from_pretrained(str(base_model_name))
        self.mask_token_id = mask_token_id

    @property
    def enabled_fusion(self) -> bool:
        # True once at least one adapter fusion layer has been added to the model.
        return len(self.model.config.adapters.fusions) != 0

    @property
    def active_head_configs(self) -> dict:
        if self.model.active_head is None:
            return {}
        return self.model.config.prediction_heads[self.model.active_head]

    @property
    def __fuse_all_adapters(self) -> Fuse:
        # Compose every adapter currently registered on the model into a single Fuse block.
        adapters = list(self.model.config.adapters)
        return Fuse(*adapters)

    def load_adapters(self, adapter_path: str, adapter_names: Iterable[str], with_heads: bool = True) -> None:
        for name in adapter_names:
            path = Path(adapter_path) / name
            self.model.load_adapter(str(path), with_head=with_heads)
    def add_classification_adapter(self, adapter_name: str, num_labels: int) -> None:
        if self.enabled_fusion:
            raise RuntimeError("Model already has a fusion layer; new adapters cannot be added to it.")
        self.model.add_adapter(adapter_name)
        self.model.add_classification_head(
            adapter_name,
            num_labels=num_labels
        )

    def remove_heads_and_add_fusion(self, head_name: str, num_labels: int) -> None:
        self.model.add_adapter_fusion(self.__fuse_all_adapters)
        self.model.set_active_adapters(self.__fuse_all_adapters)
        # Drop the per-adapter heads and replace them with a single tagging head on top of the fusion.
        for head in list(self.model.heads.keys()):
            self.model.delete_head(head)
        self.model.add_tagging_head(
            head_name,
            num_labels=num_labels
        )
    def __compute_metrics(self, pred: EvalPrediction) -> Dict[str, float]:
        # Flatten token-level labels and predictions before computing sklearn metrics.
        true_labels = pred.label_ids.ravel()
        pred_labels = pred.predictions.argmax(-1).ravel()
        report = classification_report(true_labels, pred_labels, output_dict=True)
        return {
            'accuracy': report['accuracy'],
            'f1-score-1': report['1']['f1-score'],
            'f1-score-ma': report['macro avg']['f1-score']
        }
    def __finetune(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        col_fn,
        training_args: Optional[dict]
    ) -> None:
        if training_args is None:
            training_args = {}
        training_args = TrainingArguments(
            evaluation_strategy="epoch",
            save_strategy="epoch",
            # Keep all dataset columns so the label column reaches the model unchanged.
            remove_unused_columns=False,
            **training_args
        )
        trainer = AdapterTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=col_fn,
            compute_metrics=self.__compute_metrics
        )
        trainer.train()
    def finetune_adapter(
        self,
        adapter_name: str,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        col_fn,
        training_args=None
    ):
        # Freeze the base model and every other adapter; only the selected adapter is trained.
        self.model.train_adapter(adapter_name)
        self.__finetune(train_dataset, eval_dataset, col_fn, training_args)

    def finetune_fusion(
        self,
        head_name: str,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        col_fn,
        training_args=None
    ):
        if not self.enabled_fusion:
            raise RuntimeError("A fusion layer must be added before the fusion can be fine-tuned.")
        self.model.train_adapter_fusion(self.__fuse_all_adapters)
        self.model.active_head = head_name
        self.__finetune(train_dataset, eval_dataset, col_fn, training_args)
    def evaluate_adapter(
        self,
        adapter_name: str,
        eval_dataset: Dataset,
        col_fn,
        eval_batch_size: int = 32
    ) -> Dict[str, float]:
        self.model.set_active_adapters(adapter_name)
        training_args = TrainingArguments(
            output_dir='.',
            remove_unused_columns=False,
            label_names=['labels'],
            per_device_eval_batch_size=eval_batch_size
        )
        trainer = AdapterTrainer(
            model=self.model,
            args=training_args,
            data_collator=col_fn,
            compute_metrics=self.__compute_metrics
        )
        return trainer.evaluate(eval_dataset)
    def inference_adapter(self, adapter_name: str, input_ids, attention_mask) -> Tensor:
        self.model.eval()
        self.model.set_active_adapters(adapter_name)
        with torch.no_grad():
            model_output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        # Per-token class probabilities with shape (batch, sequence_length, num_labels).
        return torch.softmax(model_output.logits, dim=2)
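A minimal usage sketch of the wrapper, assuming an adapter-transformers environment; the checkpoint name "bert-base-uncased", the "adapters" directory and the adapter names are illustrative placeholders, not taken from this repository:

# Hypothetical example; paths and adapter names are placeholders.
from transformers import AutoTokenizer

wrapper = BertAdapterModelWrapper("bert-base-uncased")
wrapper.load_adapters("adapters", ["topic", "sentiment"], with_heads=True)  # loads adapters/topic, adapters/sentiment

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["an example sentence"], return_tensors="pt")

# Run one of the loaded adapters and get per-token class probabilities.
probs = wrapper.inference_adapter("sentiment", batch["input_ids"], batch["attention_mask"])
print(probs.shape)  # (batch, sequence_length, num_labels)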