from os import PathLike
from pathlib import Path
from typing import Dict, Union, Optional, Iterable

import torch
from torch import Tensor
from torch.utils.data import Dataset
from sklearn.metrics import classification_report
from transformers import TrainingArguments, BertForSequenceClassification, EvalPrediction, Trainer
from opendelta import AdapterModel


class OpenDeltaModelWrapper:
    def __init__(self, base_model_name: Union[str, PathLike[str]], mask_token_id: int = -100):
        self.model = BertForSequenceClassification.from_pretrained(str(base_model_name))
        # Label id to ignore when computing metrics (-100 is the HuggingFace masking convention).
        self.mask_token_id = mask_token_id

    def load_adapters(self, adapter_path: str, adapter_names: Iterable[str], with_heads: bool = True) -> None:
        # Minimal sketch: assumes each adapter's state dict was saved with torch.save()
        # as "<adapter_path>/<name>.pt" (a hypothetical layout, not an OpenDelta convention).
        self.adapter_states: Dict[str, Dict[str, Tensor]] = {}
        for name in adapter_names:
            state = torch.load(Path(adapter_path) / f'{name}.pt', map_location='cpu')
            if not with_heads:
                # Keep only the delta-module weights; drop the classification head.
                state = {k: v for k, v in state.items() if not k.startswith('classifier.')}
            self.adapter_states[name] = state

    def add_classification_adapter(self, adapter_name: str, bottleneck_dim: int) -> None:
        # Record the active adapter's name and wrap the backbone with a bottleneck
        # adapter of the requested dimension.
        self.adapter_name = adapter_name
        self.delta_model = AdapterModel(self.model, bottleneck_dim=bottleneck_dim)
        # The delta modules and the newly initialized classification head are left
        # tunable; everything else is frozen in finetune_adapter().

    def __compute_metrics(self, pred: EvalPrediction) -> Dict[str, float]:
        true_labels = pred.label_ids.ravel()
        pred_labels = pred.predictions.argmax(-1).ravel()
        # Drop positions whose label equals the mask token id (default -100).
        keep = true_labels != self.mask_token_id
        report = classification_report(true_labels[keep], pred_labels[keep], output_dict=True)
        return {
            'accuracy': report['accuracy'],
            'f1-score-1': report['1']['f1-score'],
            'f1-score-ma': report['macro avg']['f1-score'],
        }

    def finetune_adapter(
        self,
        adapter_name: str,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        col_fn,
        training_args: Optional[TrainingArguments] = None,
    ) -> None:
        # Freeze everything except the delta modules and the classification head,
        # then run the shared fine-tuning loop.
        self.delta_model.freeze_module(exclude=["deltas", "classifier"])
        self.__finetune(train_dataset, eval_dataset, col_fn, training_args)
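
    def __finetune(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        col_fn,
        training_args: Optional[TrainingArguments],
    ) -> None:
        # Minimal sketch of the shared fine-tuning loop, built on the HuggingFace
        # Trainer; the fallback TrainingArguments and its output_dir are assumptions.
        if training_args is None:
            training_args = TrainingArguments(output_dir='finetune_out')
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=col_fn,
            compute_metrics=self.__compute_metrics,
        )
        trainer.train()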

    def evaluate_adapter(
        self,
        adapter_name: str,
        eval_dataset: Dataset,
        col_fn,
        eval_batch_size: int = 32,
    ) -> Dict[str, float]:
        # Minimal sketch: if load_adapters() cached a state dict under this name,
        # activate it first, then evaluate with a throwaway Trainer. The output_dir
        # below is an assumption.
        state = getattr(self, 'adapter_states', {}).get(adapter_name)
        if state is not None:
            self.model.load_state_dict(state, strict=False)
        args = TrainingArguments(output_dir='eval_out', per_device_eval_batch_size=eval_batch_size)
        trainer = Trainer(
            model=self.model,
            args=args,
            data_collator=col_fn,
            compute_metrics=self.__compute_metrics,
        )
        return trainer.evaluate(eval_dataset=eval_dataset)

    def inference_adapter(self, adapter_name: str, input_ids, attention_mask) -> Tensor:
        # Minimal sketch: activate the named adapter's cached weights if available,
        # then return the classification logits from a single forward pass.
        state = getattr(self, 'adapter_states', {}).get(adapter_name)
        if state is not None:
            self.model.load_state_dict(state, strict=False)
        self.model.eval()
        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return output.logits
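

# Usage sketch (the model name, datasets, and collator below are placeholders
# for illustration, not part of this module):
#
#   wrapper = OpenDeltaModelWrapper('bert-base-uncased')
#   wrapper.add_classification_adapter('sentiment', bottleneck_dim=48)
#   wrapper.finetune_adapter('sentiment', train_ds, eval_ds, my_collate_fn)
#   metrics = wrapper.evaluate_adapter('sentiment', eval_ds, my_collate_fn)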