"""
Factory that maps an interpretation-method name to an interpreter class
and instantiates it for a given model.
"""
import importlib
from typing import Callable, Dict, Tuple

from . import InterpretableModel, Interpreter


class InterpretationType:
    """String names of the interpretation methods known to this package."""
    GuidedBackprop = 'GuidedBackprop'
    GuidedGradCam = 'GuidedGradCam'
    GradCam = 'GradCam'
    RelCam = 'RelCam'
    DeepLift = 'DeepLift'
    Attention = 'Attention'
    AttentionSmooth = 'AttentionSmooth'
    AttentionSum = 'AttentionSum'
    ImagenetAttention = 'ImagenetAttention'
    # NOTE(review): GT has no entry in _INTERPRETER_REGISTRY, so passing it to
    # create_interpreter raises ValueError. Presumably callers handle GT
    # themselves before reaching this factory -- confirm.
    GT = 'GT'


# A callable that takes the model to interpret and returns an Interpreter.
InterpreterMaker = Callable[[InterpretableModel], Interpreter]

# Maps each supported method name to (relative submodule, class name).
# Submodules are imported only when their method is requested, so importing
# this factory does not import every interpreter implementation up front.
_INTERPRETER_REGISTRY: Dict[str, Tuple[str, str]] = {
    InterpretationType.GuidedBackprop: ('.guided_backprop', 'GuidedBackprop'),
    InterpretationType.GuidedGradCam: ('.guided_gradcam', 'GuidedGradCam'),
    InterpretationType.GradCam: ('.gradcam', 'GradCam'),
    InterpretationType.RelCam: ('.relcam', 'RelCamInterpreter'),
    InterpretationType.DeepLift: ('.deep_lift', 'DeepLift'),
    InterpretationType.Attention: ('.attention_interpreter', 'AttentionInterpreter'),
    InterpretationType.AttentionSmooth: (
        '.attention_interpreter_smooth_integrator',
        'AttentionInterpreterSmoothIntegrator',
    ),
    InterpretationType.AttentionSum: ('.attention_sum_interpreter', 'AttentionSumInterpreter'),
    InterpretationType.ImagenetAttention: (
        '.imagenet_attention_interpreter',
        'ImagenetPredictionInterpreter',
    ),
}


def _load_interpreter_maker(interpretation_method: str) -> InterpreterMaker:
    """
    Import and return the interpreter class registered for a method name.

    :param interpretation_method: The method name (one of the InterpretationType values)
    :return: the interpreter class, which is called with the model to build an interpreter
    :raises ValueError: if the method name has no registered interpreter
    """
    try:
        module_name, class_name = _INTERPRETER_REGISTRY[interpretation_method]
    except KeyError:
        supported = ', '.join(sorted(_INTERPRETER_REGISTRY))
        raise ValueError(
            f"Unknown interpretation method {interpretation_method!r}; "
            f"supported methods: {supported}"
        ) from None
    # Resolve the relative module name against this module's own package,
    # equivalent to `from .<module_name> import <class_name>`.
    module = importlib.import_module(module_name, package=__package__)
    return getattr(module, class_name)


def create_interpreter(interpretation_method: str, model: InterpretableModel) -> Interpreter:
    """
    Creates an interpreter object and returns it

    :param interpretation_method: The method name (one of the InterpretationType values)
    :param model: The model to be interpreted
    :return: an interpreter object that can run the model and interpret its output
    :raises ValueError: if interpretation_method is not a supported method
        (ValueError subclasses Exception, so existing ``except Exception`` handlers still apply)
    """
    interpreter_maker = _load_interpreter_maker(interpretation_method)
    return interpreter_maker(model)