You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ws_lap_resnet.py 4.5KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from collections import OrderedDict
  2. from typing import Tuple
  3. from ...utils.output_modifier import OutputModifier
  4. from ...utils.aux_output import AuxOutput
  5. from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
  6. from ...model_evaluation.binary_evaluator import BinaryEvaluator
  7. from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
  8. from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
  9. from ...model_evaluation.loss_evaluator import LossEvaluator
  10. from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
  11. from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
  12. from ...configs.rsna_configs import RSNAConfigs
  13. from ...models.model import Model
  14. from ..entrypoint import BaseEntrypoint
  15. from ...models.rsna.lap_resnet import RSNALAPResNet18
  16. class EntryPoint(BaseEntrypoint):
  17. def __init__(self, phase_type) -> None:
  18. self.active_min_ratio: float = 0.1
  19. self.active_max_ratio: float = 0.5
  20. self.inactive_ratio: float = 0.1
  21. super().__init__(phase_type)
  22. def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
  23. config = RSNAConfigs('RSNA_ResNet_WS', 1, 224, self.phase_type)
  24. model = RSNALAPResNet18()
  25. # weakly supervised losses based on discriminative head and attention head
  26. ws_losses = [
  27. DiscriminativeWeaklySupervisedLoss(
  28. title, model, att_score_layer, 0.25,
  29. self.inactive_ratio,
  30. (self.active_min_ratio, self.active_max_ratio), {0: [1]},
  31. discr_score_layer,
  32. w_attention_in_ordering=0.2, w_discr_in_ordering=1)
  33. for title, att_score_layer, discr_score_layer in [
  34. ('att2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer),
  35. ('att3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer),
  36. ('att4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer),
  37. ('att5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
  38. ]
  39. ]
  40. for ws_loss in ws_losses:
  41. ws_loss.configure(config)
  42. # concordance loss for attention
  43. concordance_loss = PoolConcordanceLossCalculator(
  44. 'AC', model, OrderedDict([
  45. ('att2', model.layer2[0].pool.attention_layer),
  46. ('att3', model.layer3[0].pool.attention_layer),
  47. ('att4', model.layer4[0].pool.attention_layer),
  48. ('att5', model.avgpool[0].attention_layer),
  49. ]), loss_weight=1, weights=1 / 4, diff_thresholds=0.1,
  50. labels_by_channel={0: [1]})
  51. concordance_loss.configure(config)
  52. # concordance loss for discrimination head
  53. concordance_loss2 = PoolConcordanceLossCalculator(
  54. 'DC', model, OrderedDict([
  55. ('att2', model.layer2[0].pool.discrimination_layer),
  56. ('att3', model.layer3[0].pool.discrimination_layer),
  57. ('att4', model.layer4[0].pool.discrimination_layer),
  58. ('att5', model.avgpool[0].discrimination_layer),
  59. ]), loss_weight=1, weights=0.5 / 4, diff_thresholds=0,
  60. labels_by_channel={0: [1]})
  61. concordance_loss2.configure(config)
  62. config.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
  63. 'b': BinaryEvaluator,
  64. 'l': LossEvaluator,
  65. 'f': BinForetellerEvaluator.standard_creator('foretell'),
  66. 'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
  67. }))
  68. config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'
  69. ###################################
  70. ########### Foreteller ############
  71. ###################################
  72. aux = AuxOutput(model, dict(
  73. foretell_pool2=model.layer2[0].pool.attention_layer,
  74. foretell_pool3=model.layer3[0].pool.attention_layer,
  75. foretell_pool4=model.layer4[0].pool.attention_layer,
  76. foretell_avgpool=model.avgpool[0].attention_layer,
  77. ))
  78. aux.configure(config)
  79. output_modifier = OutputModifier(model,
  80. lambda x: x.flatten(1).max(dim=-1).values,
  81. 'foretell_pool2',
  82. 'foretell_pool3',
  83. 'foretell_pool4',
  84. 'foretell_avgpool',
  85. )
  86. output_modifier.configure(config)
  87. return config, model