
classifier_adapter.py 610B

from pytorch_adapt.adapters.base_adapter import BaseGCAdapter
from pytorch_adapt.adapters.utils import with_opt
from pytorch_adapt.hooks import ClassifierHook


class ClassifierAdapter(BaseGCAdapter):
    """
    Wraps [ClassifierHook][pytorch_adapt.hooks.ClassifierHook].

    |Container|Required keys|
    |---|---|
    |models|```["G", "C"]```|
    |optimizers|```["G", "C"]```|
    """

    def init_hook(self, hook_kwargs):
        # Wrap the optimizer keys so the hook can step them during its call.
        opts = with_opt(list(self.optimizers.keys()))
        self.hook = self.hook_cls(opts, **hook_kwargs)

    @property
    def hook_cls(self):
        return ClassifierHook
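For context, the adapter is built from model and optimizer containers keyed "G" and "C", as listed in the table above. The sketch below is a minimal, assumed usage example: it relies on pytorch_adapt's Models and Optimizers containers and uses placeholder networks whose shapes are purely illustrative; the exact training-loop integration (for example via a framework wrapper) may differ.

import torch

from pytorch_adapt.containers import Models, Optimizers

from classifier_adapter import ClassifierAdapter

# Placeholder networks: G extracts features, C maps features to class logits.
G = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 128))
C = torch.nn.Linear(128, 10)

# Both containers use the "G" / "C" keys required by the adapter.
models = Models({"G": G, "C": C})
optimizers = Optimizers((torch.optim.Adam, {"lr": 1e-4}))

adapter = ClassifierAdapter(models=models, optimizers=optimizers)
# After construction, init_hook has wrapped the optimizer keys via with_opt and
# stored a ClassifierHook in adapter.hook, ready to be driven by a training loop.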