You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ws_lap_inception.py 4.8KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from typing import Tuple
  2. from collections import OrderedDict
  3. from ...configs.rsna_configs import RSNAConfigs
  4. from ...models.model import Model
  5. from ..entrypoint import BaseEntrypoint
  6. from ...models.rsna.lap_inception import RSNALAPInception
  7. from ...modules.lap import LAP
  8. from ...modules.adaptive_lap import AdaptiveLAP
  9. from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
  10. from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
  11. from ...utils.aux_output import AuxOutput
  12. from ...utils.output_modifier import OutputModifier
  13. from ...model_evaluation.binary_evaluator import BinaryEvaluator
  14. from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
  15. from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
  16. from ...model_evaluation.loss_evaluator import LossEvaluator
  17. from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
  18. def lap_factory(channel):
  19. return LAP(
  20. channel, 2, 2, hidden_channels=[8],
  21. sigmoid_scale=0.1, discriminative_attention=True)
  22. def adaptive_lap_factory(channel):
  23. return AdaptiveLAP(channel, sigmoid_scale=0.1, discriminative_attention=True)
  24. class EntryPoint(BaseEntrypoint):
  25. def __init__(self, phase_type) -> None:
  26. self.active_min_ratio: float = 0.1
  27. self.active_max_ratio: float = 0.5
  28. self.inactive_ratio: float = 0.1
  29. super().__init__(phase_type)
  30. def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
  31. conf = RSNAConfigs('RSNA_Inception_WS', 4, 256, self.phase_type)
  32. model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory)
  33. # weakly supervised losses based on discriminative head and attention head
  34. ws_losses = [
  35. DiscriminativeWeaklySupervisedLoss(
  36. title, model, att_score_layer, 0.25,
  37. self.inactive_ratio,
  38. (self.active_min_ratio, self.active_max_ratio), {0: [1]},
  39. discr_score_layer,
  40. w_attention_in_ordering=0.2, w_discr_in_ordering=1)
  41. for title, att_score_layer, discr_score_layer in [
  42. ('Pool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer),
  43. ('Pool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer),
  44. ('Pool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer),
  45. ('PoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
  46. ]
  47. ]
  48. for ws_loss in ws_losses:
  49. ws_loss.configure(conf)
  50. # concordance loss for attention
  51. concordance_loss = PoolConcordanceLossCalculator(
  52. 'AC', model, OrderedDict([
  53. ('att2', model.maxpool2.attention_layer),
  54. ('att6', model.Mixed_6a.pool.attention_layer),
  55. ('att7', model.Mixed_7a.pool.attention_layer),
  56. ('attG', model.avgpool[0].attention_layer),
  57. ]), loss_weight=1, weights=1 / 4, diff_thresholds=0.1,
  58. labels_by_channel={0: [1]})
  59. concordance_loss.configure(conf)
  60. # concordance loss for discrimination head
  61. concordance_loss2 = PoolConcordanceLossCalculator(
  62. 'DC', model, OrderedDict([
  63. ('D-att2', model.maxpool2.discrimination_layer),
  64. ('D-att6', model.Mixed_6a.pool.discrimination_layer),
  65. ('D-att7', model.Mixed_7a.pool.discrimination_layer),
  66. ('D-attG', model.avgpool[0].discrimination_layer),
  67. ]), loss_weight=1, weights=0.5 / 4, diff_thresholds=0,
  68. labels_by_channel={0: [1]})
  69. concordance_loss2.configure(conf)
  70. conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
  71. 'b': BinaryEvaluator,
  72. 'l': LossEvaluator,
  73. 'f': BinForetellerEvaluator.standard_creator('foretell'),
  74. 'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
  75. }))
  76. conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'
  77. ###################################
  78. ########### Foreteller ############
  79. ###################################
  80. aux = AuxOutput(model, dict(
  81. foretell_pool2=model.maxpool2.attention_layer,
  82. foretell_pool6=model.Mixed_6a.pool.attention_layer,
  83. foretell_pool7=model.Mixed_7a.pool.attention_layer,
  84. foretell_avgpool=model.avgpool[0].attention_layer,
  85. ))
  86. aux.configure(conf)
  87. output_modifier = OutputModifier(model,
  88. lambda x: x.flatten(1).max(dim=-1).values,
  89. 'foretell_pool2',
  90. 'foretell_pool6',
  91. 'foretell_pool7',
  92. 'foretell_avgpool',
  93. )
  94. output_modifier.configure(conf)
  95. return conf, model