123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- from collections import OrderedDict
- from typing import Tuple
-
- from ...utils.output_modifier import OutputModifier
- from ...utils.aux_output import AuxOutput
- from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
- from ...model_evaluation.binary_evaluator import BinaryEvaluator
- from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
- from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
- from ...model_evaluation.loss_evaluator import LossEvaluator
- from ...criteria.bb_supervised import BBLoss
- from ...configs.rsna_configs import RSNAConfigs
- from ...models.model import Model
- from ..entrypoint import BaseEntrypoint
- from ...models.rsna.lap_resnet import RSNALAPResNet18, get_adaptive_pool_factory, get_pool_factory
-
-
class EntryPoint(BaseEntrypoint):
    """Entry point that builds the ``'RSNA_ResNet_BB'`` config and an ``RSNALAPResNet18``.

    The attention layers of the model's pooling modules are consumed twice:
    by the ``BBLoss`` criterion and as auxiliary "foreteller" outputs.
    """

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Create the experiment configuration and the model it is paired with.

        Returns:
            ``(config, model)``. Before returning, ``config`` has had the
            ``BBLoss`` criterion, the ``AuxOutput`` hooks, the
            ``OutputModifier`` and the multi-evaluator factory configured on
            it. The best epoch is selected by the ``'b_BAcc'`` metric.
        """
        # NOTE(review): the positional 1 and 224 are interpreted by
        # RSNAConfigs (possibly batch size / input size) -- confirm there.
        conf = RSNAConfigs('RSNA_ResNet_BB', 1, 224, self.phase_type)

        net = RSNALAPResNet18(
            pool_factory=get_pool_factory(discriminative_attention=False),
            adaptive_factory=get_adaptive_pool_factory(discriminative_attention=False),
            sigmoid_scale=0.1,
        )

        def pool_attention_layers():
            # Attention layers of the first block of layer2..layer4 and of
            # avgpool[0], in that order. Looked up fresh at each call rather
            # than cached once.
            return (
                net.layer2[0].pool.attention_layer,
                net.layer3[0].pool.attention_layer,
                net.layer4[0].pool.attention_layer,
                net.avgpool[0].attention_layer,
            )

        # NOTE(review): the trailing positional arguments (1, 'interpretations',
        # 1, 0.5) are interpreted by BBLoss; their meaning is not visible here.
        bb_loss = BBLoss(
            'BB',
            net,
            dict(zip(('att2', 'att3', 'att4', 'att5'), pool_attention_layers())),
            1,
            'interpretations',
            1,
            0.5,
        )
        bb_loss.configure(conf)

        # Evaluators keyed by short prefixes. Insertion order is kept explicit
        # via OrderedDict.
        evaluators = OrderedDict()
        evaluators['b'] = BinaryEvaluator
        evaluators['l'] = LossEvaluator
        evaluators['f'] = BinForetellerEvaluator.standard_creator('foretell')
        evaluators['bf'] = BinFaithfulnessEvaluator.standard_creator('foretell')
        conf.evaluator_cls = (
            MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(evaluators)
        )
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---- Foreteller ----
        # Expose the same attention layers as named auxiliary outputs.
        foretell_names = (
            'foretell_pool2',
            'foretell_pool3',
            'foretell_pool4',
            'foretell_avgpool',
        )
        aux_output = AuxOutput(net, dict(zip(foretell_names, pool_attention_layers())))
        aux_output.configure(conf)

        def per_sample_max(x):
            # Flatten every dimension after dim 0 and keep the maximum, giving
            # one value per dim-0 entry. Assumes x is a torch tensor whose
            # dim 0 is the batch axis -- TODO confirm.
            return x.flatten(1).max(dim=-1).values

        modifier = OutputModifier(net, per_sample_max, *foretell_names)
        modifier.configure(conf)

        return conf, net
|