You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

bb_lap_resnet.py 2.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from collections import OrderedDict
  2. from typing import Tuple
  3. from ...utils.output_modifier import OutputModifier
  4. from ...utils.aux_output import AuxOutput
  5. from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
  6. from ...model_evaluation.binary_evaluator import BinaryEvaluator
  7. from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
  8. from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
  9. from ...model_evaluation.loss_evaluator import LossEvaluator
  10. from ...criteria.bb_supervised import BBLoss
  11. from ...configs.rsna_configs import RSNAConfigs
  12. from ...models.model import Model
  13. from ..entrypoint import BaseEntrypoint
  14. from ...models.rsna.lap_resnet import RSNALAPResNet18, get_adaptive_pool_factory, get_pool_factory
  15. class EntryPoint(BaseEntrypoint):
  16. def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
  17. config = RSNAConfigs('RSNA_ResNet_BB', 1, 224, self.phase_type)
  18. model = RSNALAPResNet18(
  19. pool_factory=get_pool_factory(discriminative_attention=False),
  20. adaptive_factory=get_adaptive_pool_factory(discriminative_attention=False),
  21. sigmoid_scale=0.1,
  22. )
  23. bb_losses = BBLoss(
  24. 'BB', model, {
  25. 'att2': model.layer2[0].pool.attention_layer,
  26. 'att3': model.layer3[0].pool.attention_layer,
  27. 'att4': model.layer4[0].pool.attention_layer,
  28. 'att5': model.avgpool[0].attention_layer,
  29. }
  30. , 1,
  31. 'interpretations', 1, 0.5)
  32. bb_losses.configure(config)
  33. config.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
  34. 'b': BinaryEvaluator,
  35. 'l': LossEvaluator,
  36. 'f': BinForetellerEvaluator.standard_creator('foretell'),
  37. 'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
  38. }))
  39. config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'
  40. ###################################
  41. ########### Foreteller ############
  42. ###################################
  43. aux = AuxOutput(model, dict(
  44. foretell_pool2=model.layer2[0].pool.attention_layer,
  45. foretell_pool3=model.layer3[0].pool.attention_layer,
  46. foretell_pool4=model.layer4[0].pool.attention_layer,
  47. foretell_avgpool=model.avgpool[0].attention_layer,
  48. ))
  49. aux.configure(config)
  50. output_modifier = OutputModifier(model,
  51. lambda x: x.flatten(1).max(dim=-1).values,
  52. 'foretell_pool2',
  53. 'foretell_pool3',
  54. 'foretell_pool4',
  55. 'foretell_avgpool',
  56. )
  57. output_modifier.configure(config)
  58. return config, model