|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
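"""
Factory for interpreter objects: maps an interpretation method name to the
interpreter class that implements it and instantiates that class for a model.
"""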
- from typing import Callable
- from . import InterpretableModel, Interpreter
-
-
class InterpretationType:
    """
    Names of the interpretation methods, kept as plain string constants.

    Every attribute's value is identical to the attribute's own name, so
    ``InterpretationType.GradCam`` and the literal ``"GradCam"`` are
    interchangeable wherever a method name is expected.
    """

    GuidedBackprop: str = "GuidedBackprop"
    GuidedGradCam: str = "GuidedGradCam"
    GradCam: str = "GradCam"
    RelCam: str = "RelCam"
    DeepLift: str = "DeepLift"
    Attention: str = "Attention"
    AttentionSmooth: str = "AttentionSmooth"
    AttentionSum: str = "AttentionSum"
    ImagenetAttention: str = "ImagenetAttention"
    # NOTE(review): presumably "ground truth"; create_interpreter has no
    # interpreter for this value and rejects it — confirm how callers use it.
    GT: str = "GT"
-
-
# Type of a callable that builds an Interpreter for a given InterpretableModel
# (e.g. an interpreter class, which create_interpreter calls with the model).
InterpreterMaker = Callable[[InterpretableModel], Interpreter]


# Maps each supported interpretation method to the (sibling module, class name)
# implementing it. The module is imported only when that method is requested,
# so selecting one method never loads the others' modules.
# NOTE(review): the lazy import is presumably there to avoid import cycles with
# this package and/or heavy optional dependencies — confirm before hoisting.
_INTERPRETER_REGISTRY = {
    InterpretationType.GuidedBackprop: ('guided_backprop', 'GuidedBackprop'),
    InterpretationType.GuidedGradCam: ('guided_gradcam', 'GuidedGradCam'),
    InterpretationType.GradCam: ('gradcam', 'GradCam'),
    InterpretationType.RelCam: ('relcam', 'RelCamInterpreter'),
    InterpretationType.DeepLift: ('deep_lift', 'DeepLift'),
    InterpretationType.Attention: ('attention_interpreter', 'AttentionInterpreter'),
    InterpretationType.AttentionSmooth: (
        'attention_interpreter_smooth_integrator',
        'AttentionInterpreterSmoothIntegrator',
    ),
    InterpretationType.AttentionSum: ('attention_sum_interpreter', 'AttentionSumInterpreter'),
    InterpretationType.ImagenetAttention: (
        'imagenet_attention_interpreter',
        'ImagenetPredictionInterpreter',
    ),
}


def create_interpreter(interpretation_method: str, model: InterpretableModel) -> Interpreter:
    """
    Create the interpreter for ``interpretation_method`` wrapped around ``model``.

    :param interpretation_method: The method name; one of the InterpretationType
        values listed in ``_INTERPRETER_REGISTRY``
    :param model: The model to be interpreted
    :return: an interpreter object that can run the model and interpret its output
    :raises ValueError: if ``interpretation_method`` is not a supported method
        (``InterpretationType.GT`` has no interpreter and is rejected as well)
    """
    import importlib  # local import, in keeping with the lazy per-method imports

    try:
        module_name, class_name = _INTERPRETER_REGISTRY[interpretation_method]
    except KeyError:
        supported = ', '.join(sorted(_INTERPRETER_REGISTRY))
        # ValueError subclasses Exception, so callers catching the previous
        # bare Exception keep working; the message is now a readable string.
        raise ValueError(
            f'Unknown interpretation method {interpretation_method!r}; '
            f'expected one of: {supported}'
        ) from None

    # Equivalent to `from .<module_name> import <class_name>` relative to this package.
    module = importlib.import_module(f'.{module_name}', package=__package__)
    interpreter_maker: InterpreterMaker = getattr(module, class_name)
    return interpreter_maker(model)
-
|