from os import PathLike
from pathlib import Path
from typing import Dict, Union, Optional, Iterable

import torch
from torch import Tensor
from torch.utils.data import Dataset
from sklearn.metrics import classification_report
from transformers import TrainingArguments, BertAdapterModel, EvalPrediction, AdapterTrainer
from transformers.adapters import Fuse


class BertAdapterModelWrapper:
    """Thin wrapper around BertAdapterModel for adding, fusing, training and evaluating adapters."""

    def __init__(self, base_model_name: Union[str, PathLike], mask_token_id: int = -100):
        self.model = BertAdapterModel.from_pretrained(str(base_model_name))
        # Label id conventionally ignored by the loss (-100), kept for downstream use.
        self.mask_token_id = mask_token_id

    @property
    def enabled_fusion(self) -> bool:
        return len(self.model.config.adapters.fusions) != 0

    @property
    def active_head_configs(self) -> dict:
        if self.model.active_head is None:
            return {}
        return self.model.config.prediction_heads[self.model.active_head]

    @property
    def __fuse_all_adapters(self) -> Fuse:
        # Build a Fuse composition over every adapter currently registered on the model.
        adapters = list(self.model.config.adapters)
        return Fuse(*adapters)

    def load_adapters(self, adapter_path: str, adapter_names: Iterable[str], with_heads: bool = True) -> None:
        for name in adapter_names:
            path = Path(adapter_path) / name
            self.model.load_adapter(str(path), with_head=with_heads)

    def add_classification_adapter(self, adapter_name: str, num_labels: int) -> None:
        if self.enabled_fusion:
            raise RuntimeError("Cannot add new adapters: the model already has a fusion layer.")
        self.model.add_adapter(adapter_name)
        self.model.add_classification_head(
            adapter_name,
            num_labels=num_labels
        )

    def remove_heads_and_add_fusion(self, head_name: str, num_labels: int) -> None:
        self.model.add_adapter_fusion(self.__fuse_all_adapters)
        self.model.set_active_adapters(self.__fuse_all_adapters)

        # Drop the per-adapter heads and replace them with a single tagging head on top of the fusion.
        for head in list(self.model.heads.keys()):
            self.model.delete_head(head)

        self.model.add_tagging_head(
            head_name,
            num_labels=num_labels
        )

    def __compute_metrics(self, pred: EvalPrediction) -> Dict[str, float]:
        true_labels = pred.label_ids.ravel()
        pred_labels = pred.predictions.argmax(-1).ravel()
        report = classification_report(true_labels, pred_labels, output_dict=True)
        return {
            'accuracy': report['accuracy'],
            'f1-score-1': report['1']['f1-score'],
            'f1-score-ma': report['macro avg']['f1-score']
        }

    def __finetune(
            self,
            train_dataset: Dataset,
            eval_dataset: Dataset,
            col_fn,
            training_args: Optional[dict]
    ) -> None:
        if training_args is None:
            training_args = {}
        # TrainingArguments requires an output_dir; fall back to the current directory if none is given.
        training_args.setdefault("output_dir", ".")

        training_args = TrainingArguments(
            evaluation_strategy="epoch",
            save_strategy="epoch",
            # The next line is important to ensure the dataset labels are properly passed to the model
            remove_unused_columns=False,
            **training_args
        )

        trainer = AdapterTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=col_fn,
            compute_metrics=self.__compute_metrics
        )
        trainer.train()

    def finetune_adapter(
            self,
            adapter_name: str,
            train_dataset: Dataset,
            eval_dataset: Dataset,
            col_fn,
            training_args=None
    ):
        # Freeze the base model and all other adapters; unfreeze only the selected adapter.
        self.model.train_adapter(adapter_name)
        self.__finetune(train_dataset, eval_dataset, col_fn, training_args)

    def finetune_fusion(
            self,
            head_name: str,
            train_dataset: Dataset,
            eval_dataset: Dataset,
            col_fn,
            training_args=None
    ):
        if not self.enabled_fusion:
            raise RuntimeError("A fusion layer must be added before fine-tuning the fusion.")
        self.model.train_adapter_fusion(self.__fuse_all_adapters)
        self.model.active_head = head_name
        self.__finetune(train_dataset, eval_dataset, col_fn, training_args)

    def evaluate_adapter(
            self,
            adapter_name: str,
            eval_dataset: Dataset,
            col_fn,
            eval_batch_size: int = 32
    ) -> Dict[str, float]:
        self.model.set_active_adapters(adapter_name)

        training_args = TrainingArguments(
            output_dir='.',
            remove_unused_columns=False,
            label_names=['labels'],
            per_device_eval_batch_size=eval_batch_size
        )

        trainer = AdapterTrainer(
            model=self.model,
            args=training_args,
            data_collator=col_fn,
            compute_metrics=self.__compute_metrics
        )
        return trainer.evaluate(eval_dataset)

    def inference_adapter(self, adapter_name: str, input_ids, attention_mask) -> Tensor:
        self.model.eval()
        self.model.set_active_adapters(adapter_name)

        with torch.no_grad():
            model_output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        # Softmax over the label dimension; the last axis works for both sequence- and token-level heads.
        return torch.softmax(model_output.logits, dim=-1)
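

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the wrapper itself. It assumes the
    # "bert-base-uncased" checkpoint can be downloaded and uses a hypothetical
    # adapter name ("sentiment"); plug in your own tokenized datasets and data
    # collator before calling finetune_adapter / evaluate_adapter.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    wrapper = BertAdapterModelWrapper("bert-base-uncased")

    # Add an (untrained) classification adapter and run a forward pass through it.
    wrapper.add_classification_adapter("sentiment", num_labels=2)
    batch = tokenizer(["great movie!", "terrible plot"], return_tensors="pt", padding=True)
    probs = wrapper.inference_adapter("sentiment", batch["input_ids"], batch["attention_mask"])
    print(probs)

    # Fine-tuning follows the same pattern with your own datasets and collator, e.g.:
    # wrapper.finetune_adapter("sentiment", train_dataset=train_ds, eval_dataset=eval_ds,
    #                          col_fn=DataCollatorWithPadding(tokenizer),
    #                          training_args={"output_dir": "runs/sentiment"})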