You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

opendelta.py 2.1KB

3 months ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. from os import PathLike
  2. from pathlib import Path
  3. from typing import Dict, Union, Optional, Iterable
  4. import numpy as np
  5. import torch
  6. from torch import Tensor
  7. from torch.utils.data import Dataset
  8. from sklearn.metrics import classification_report
  9. from transformers import TrainingArguments, BertForSequenceClassification, EvalPrediction, Trainer
  10. from opendelta import AdapterModel
  11. class OpenDeltaModelWrapper:
  12. def __init__(self, base_model_name: Union[str, PathLike[str]], mask_token_id: int = -100):
  13. self.model = BertForSequenceClassification.from_pretrained(str(base_model_name))
  14. self.mask_token_id = mask_token_id
  15. def load_adapters(self, adapter_path: str, adapter_names: Iterable[str], with_heads: bool = True) -> None:
  16. # TODO
  17. pass
  18. def add_classification_adapter(self, adapter_name: str, bottleneck_dim: int) -> None:
  19. # TODO
  20. self.delta_model = AdapterModel(base_model, bottleneck_dim=48)
  21. # leave the delta tuning modules and the newly initialized classification head tunable.
  22. def __compute_metrics(self, pred: EvalPrediction) -> Dict[str, float]:
  23. true_labels = pred.label_ids.ravel()
  24. pred_labels = pred.predictions.argmax(-1).ravel()
  25. report = classification_report(true_labels, pred_labels, output_dict=True)
  26. return {
  27. 'accuracy': report['accuracy'],
  28. 'f1-score-1': report['1']['f1-score'],
  29. 'f1-score-ma': report['macro avg']['f1-score']
  30. }
  31. def finetune_adapter(
  32. self, adapter_name: str,
  33. train_dataset: Dataset,
  34. eval_dataset: Dataset,
  35. col_fn,
  36. training_args=None
  37. ):
  38. self.delta_model.freeze_module(exclude=["deltas", "classifier"]) # freeze other adapters and unfreeze selected adapter
  39. self.__finetune(train_dataset, eval_dataset, col_fn, training_args)
  40. def evaluate_adapter(
  41. self,
  42. adapter_name: str,
  43. eval_dataset: Dataset,
  44. col_fn,
  45. eval_batch_size: int = 32
  46. ) -> Dict[str, float]:
  47. # TODO
  48. pass
  49. def inference_adapter(self, adapter_name: str, input_ids, attention_mask) -> Tensor:
  50. # TODO
  51. pass