You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

interpreter_maker.py 2.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. """
  2. Creates an interpreter object and returns it
  3. """
  4. from typing import Callable
  5. from . import InterpretableModel, Interpreter
  6. class InterpretationType:
  7. GuidedBackprop = 'GuidedBackprop'
  8. GuidedGradCam = 'GuidedGradCam'
  9. GradCam = 'GradCam'
  10. RelCam = 'RelCam'
  11. DeepLift = 'DeepLift'
  12. Attention = 'Attention'
  13. AttentionSmooth = 'AttentionSmooth'
  14. AttentionSum = 'AttentionSum'
  15. ImagenetAttention = 'ImagenetAttention'
  16. GT = 'GT'
  17. InterpreterMaker = Callable[[InterpretableModel], Interpreter]
  18. def create_interpreter(interpretation_method: str, model: InterpretableModel) -> Interpreter:
  19. """
  20. Creates an interpreter object and returns it
  21. :param interpretation_method: The method name
  22. :param model: The model to be interpreted
  23. :return: an interpreter object that can run the model and interpret its output
  24. """
  25. if interpretation_method == InterpretationType.GuidedBackprop:
  26. from .guided_backprop import GuidedBackprop as interpreter_maker
  27. elif interpretation_method == InterpretationType.GuidedGradCam:
  28. from .guided_gradcam import GuidedGradCam as interpreter_maker
  29. elif interpretation_method == InterpretationType.GradCam:
  30. from .gradcam import GradCam as interpreter_maker
  31. elif interpretation_method == InterpretationType.RelCam:
  32. from .relcam import RelCamInterpreter as interpreter_maker
  33. elif interpretation_method == InterpretationType.DeepLift:
  34. from .deep_lift import DeepLift as interpreter_maker
  35. elif interpretation_method == InterpretationType.Attention:
  36. from .attention_interpreter import AttentionInterpreter as interpreter_maker
  37. elif interpretation_method == InterpretationType.AttentionSmooth:
  38. from .attention_interpreter_smooth_integrator import AttentionInterpreterSmoothIntegrator as interpreter_maker
  39. elif interpretation_method == InterpretationType.AttentionSum:
  40. from .attention_sum_interpreter import AttentionSumInterpreter as interpreter_maker
  41. elif interpretation_method == InterpretationType.ImagenetAttention:
  42. from .imagenet_attention_interpreter import ImagenetPredictionInterpreter as interpreter_maker
  43. else:
  44. raise Exception('Unknown interpretation method ', interpretation_method)
  45. return interpreter_maker(model)