123456789101112131415161718192021 |
- from pytorch_adapt.adapters.base_adapter import BaseGCAdapter
- from pytorch_adapt.adapters.utils import with_opt
- from pytorch_adapt.hooks import ClassifierHook
-
- class ClassifierAdapter(BaseGCAdapter):
- """
- Wraps [AlignerPlusCHook][pytorch_adapt.hooks.AlignerPlusCHook].
-
- |Container|Required keys|
- |---|---|
- |models|```["G", "C"]```|
- |optimizers|```["G", "C"]```|
- """
-
- def init_hook(self, hook_kwargs):
- opts = with_opt(list(self.optimizers.keys()))
- self.hook = self.hook_cls(opts, **hook_kwargs)
-
- @property
- def hook_cls(self):
- return ClassifierHook
|