You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

bb_lap_inception.py 2.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from collections import OrderedDict
  2. from typing import Tuple
  3. from ...configs.rsna_configs import RSNAConfigs
  4. from ...models.model import Model
  5. from ..entrypoint import BaseEntrypoint
  6. from ...models.rsna.lap_inception import RSNALAPInception
  7. from ...modules.lap import LAP
  8. from ...modules.adaptive_lap import AdaptiveLAP
  9. from ...criteria.bb_supervised import BBLoss
  10. from ...utils.aux_output import AuxOutput
  11. from ...utils.output_modifier import OutputModifier
  12. from ...model_evaluation.binary_evaluator import BinaryEvaluator
  13. from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
  14. from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
  15. from ...model_evaluation.loss_evaluator import LossEvaluator
  16. from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
  17. def lap_factory(channel):
  18. return LAP(
  19. channel, 2, 2, hidden_channels=[8],
  20. sigmoid_scale=0.1, discriminative_attention=False)
  21. def adaptive_lap_factory(channel):
  22. return AdaptiveLAP(channel, [], 0.1, discriminative_attention=False)
  23. class EntryPoint(BaseEntrypoint):
  24. def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
  25. conf = RSNAConfigs('RSNA_Inception_BB', 6, 256, self.phase_type)
  26. model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory)
  27. bb_losses = BBLoss(
  28. 'BB', model, {
  29. 'Pool2': model.maxpool2.attention_layer,
  30. 'Pool6': model.Mixed_6a.pool.attention_layer,
  31. 'Pool7': model.Mixed_7a.pool.attention_layer,
  32. 'PoolG': model.avgpool[0].attention_layer,
  33. }
  34. , 1,
  35. 'interpretations', 1, 0.5)
  36. bb_losses.configure(conf)
  37. conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
  38. 'b': BinaryEvaluator,
  39. 'l': LossEvaluator,
  40. 'f': BinForetellerEvaluator.standard_creator('foretell'),
  41. 'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
  42. }))
  43. conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'
  44. ###################################
  45. ########### Foreteller ############
  46. ###################################
  47. aux = AuxOutput(model, dict(
  48. foretell_pool2=model.maxpool2.attention_layer,
  49. foretell_pool6=model.Mixed_6a.pool.attention_layer,
  50. foretell_pool7=model.Mixed_7a.pool.attention_layer,
  51. foretell_avgpool=model.avgpool[0].attention_layer,
  52. ))
  53. aux.configure(conf)
  54. output_modifier = OutputModifier(model,
  55. lambda x: x.flatten(1).max(dim=-1).values,
  56. 'foretell_pool2',
  57. 'foretell_pool6',
  58. 'foretell_pool7',
  59. 'foretell_avgpool',
  60. )
  61. output_modifier.configure(conf)
  62. return conf, model