1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- from collections import OrderedDict
- from typing import Tuple
- from ...configs.rsna_configs import RSNAConfigs
- from ...models.model import Model
- from ..entrypoint import BaseEntrypoint
- from ...models.rsna.lap_inception import RSNALAPInception
- from ...modules.lap import LAP
- from ...modules.adaptive_lap import AdaptiveLAP
- from ...criteria.bb_supervised import BBLoss
- from ...utils.aux_output import AuxOutput
- from ...utils.output_modifier import OutputModifier
- from ...model_evaluation.binary_evaluator import BinaryEvaluator
- from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
- from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
- from ...model_evaluation.loss_evaluator import LossEvaluator
- from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
-
-
def lap_factory(channel):
    """Build the ``LAP`` module that ``EntryPoint`` hands to ``RSNALAPInception``.

    Args:
        channel: Channel count, forwarded as ``LAP``'s first positional argument.

    Returns:
        A new ``LAP`` configured with ``hidden_channels=[8]``,
        ``sigmoid_scale=0.1`` and ``discriminative_attention=False``.
    """
    # A fresh kwargs mapping (and a fresh ``[8]`` list) is created per call,
    # so no module instance shares mutable configuration with another.
    lap_options = {
        'hidden_channels': [8],
        'sigmoid_scale': 0.1,
        'discriminative_attention': False,
    }
    # NOTE(review): the two positional ``2`` values are presumably a pooling
    # kernel size and stride -- confirm against LAP's signature.
    return LAP(channel, 2, 2, **lap_options)
-
-
def adaptive_lap_factory(channel):
    """Build the ``AdaptiveLAP`` module that ``EntryPoint`` hands to ``RSNALAPInception``.

    Args:
        channel: Channel count, forwarded as ``AdaptiveLAP``'s first positional
            argument.

    Returns:
        A new ``AdaptiveLAP`` built with an empty list and ``0.1`` as its next
        two positional arguments and ``discriminative_attention=False``.
    """
    # NOTE(review): by analogy with lap_factory these are presumably the
    # hidden-channel sizes and the sigmoid scale -- confirm in AdaptiveLAP.
    hidden_sizes = []
    scale = 0.1
    return AdaptiveLAP(
        channel,
        hidden_sizes,
        scale,
        discriminative_attention=False,
    )
-
-
class EntryPoint(BaseEntrypoint):
    """Entry point for the ``'RSNA_Inception_BB'`` experiment.

    Builds an ``RSNAConfigs`` and an ``RSNALAPInception`` model, then attaches
    to the config: a ``BBLoss`` over four attention layers of the model, a
    multi-evaluator, auxiliary "foretell" outputs taken from the same four
    attention layers, and an ``OutputModifier`` applied to those outputs.
    """

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Create and configure the experiment config and model.

        Returns:
            ``(conf, model)``: the configured ``RSNAConfigs`` and the
            ``RSNALAPInception`` instance the config's components refer to.
        """
        # NOTE(review): the meaning of 6 / 256 (RSNAConfigs) and 0.4
        # (RSNALAPInception) is not visible here -- confirm in their definitions.
        conf = RSNAConfigs('RSNA_Inception_BB', 6, 256, self.phase_type)
        model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory)

        bb_losses = BBLoss(
            'BB',
            model,
            {
                'Pool2': model.maxpool2.attention_layer,
                'Pool6': model.Mixed_6a.pool.attention_layer,
                'Pool7': model.Mixed_7a.pool.attention_layer,
                'PoolG': model.avgpool[0].attention_layer,
            },
            1,
            'interpretations',
            1,
            0.5,
        )
        bb_losses.configure(conf)

        # Build the OrderedDict from (key, value) pairs so the evaluator order
        # is explicit instead of relying on dict-literal ordering.
        conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(
            OrderedDict([
                ('b', BinaryEvaluator),
                ('l', LossEvaluator),
                ('f', BinForetellerEvaluator.standard_creator('foretell')),
                ('bf', BinFaithfulnessEvaluator.standard_creator('foretell')),
            ]))
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        ###################################
        ########### Foreteller ############
        ###################################

        # Single source of truth for the foretell aux outputs: the same names,
        # in the same order, are registered with AuxOutput and then targeted by
        # OutputModifier (previously they were typed out twice).
        foretell_pairs = [
            ('foretell_pool2', model.maxpool2.attention_layer),
            ('foretell_pool6', model.Mixed_6a.pool.attention_layer),
            ('foretell_pool7', model.Mixed_7a.pool.attention_layer),
            ('foretell_avgpool', model.avgpool[0].attention_layer),
        ]
        foretell_names = tuple(name for name, _ in foretell_pairs)

        aux = AuxOutput(model, dict(foretell_pairs))
        aux.configure(conf)

        # Flattens everything after dim 0 and keeps the max over it, i.e. one
        # scalar per leading index (presumably per sample; x assumed to be a
        # torch tensor -- confirm).
        output_modifier = OutputModifier(
            model,
            lambda x: x.flatten(1).max(dim=-1).values,
            *foretell_names,
        )
        output_modifier.configure(conf)

        return conf, model
|