from collections import OrderedDict
from typing import Tuple

from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.lap_inception import RSNALAPInception
from ...modules.lap import LAP
from ...modules.adaptive_lap import AdaptiveLAP
from ...criteria.bb_supervised import BBLoss
from ...utils.aux_output import AuxOutput
from ...utils.output_modifier import OutputModifier
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator


def lap_factory(channel):
    """Create a ``LAP`` module for ``channel`` input channels.

    Passed to ``RSNALAPInception`` as its LAP factory.
    """
    # NOTE(review): the two positional 2s are presumably kernel size and
    # stride — confirm against LAP's constructor signature.
    return LAP(
        channel,
        2,
        2,
        hidden_channels=[8],
        sigmoid_scale=0.1,
        discriminative_attention=False,
    )


def adaptive_lap_factory(channel):
    """Create an ``AdaptiveLAP`` module for ``channel`` input channels.

    Passed to ``RSNALAPInception`` as its adaptive-LAP factory.
    """
    # NOTE(review): the empty list presumably means "no hidden layers"
    # (cf. hidden_channels=[8] in lap_factory) — confirm against AdaptiveLAP.
    return AdaptiveLAP(
        channel,
        [],
        0.1,
        discriminative_attention=False,
    )


class EntryPoint(BaseEntrypoint):
    """RSNA LAP-Inception entry point with BB supervision and foreteller outputs."""

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Build the RSNA config and LAP-Inception model, then wire them up.

        The following are constructed and each is registered through its
        ``configure(conf)`` call:

        * a ``BBLoss`` over four attention layers,
        * a multi-evaluator,
        * an ``AuxOutput`` exposing the same four layers as ``foretell_*``
          outputs,
        * an ``OutputModifier`` that reduces those outputs.

        Returns:
            The ``(conf, model)`` pair.
        """
        conf = RSNAConfigs('RSNA_Inception_BB', 6, 256, self.phase_type)
        model = RSNALAPInception(0.4, lap_factory, adaptive_lap_factory)

        def attention_layers():
            # Looked up fresh on every call (not cached), so each consumer sees
            # the model attributes as they are at the moment it is built.
            return (
                model.maxpool2.attention_layer,
                model.Mixed_6a.pool.attention_layer,
                model.Mixed_7a.pool.attention_layer,
                model.avgpool[0].attention_layer,
            )

        bb_keys = ('Pool2', 'Pool6', 'Pool7', 'PoolG')
        bb_losses = BBLoss(
            'BB',
            model,
            dict(zip(bb_keys, attention_layers())),
            1,
            'interpretations',
            1,
            0.5,
        )
        bb_losses.configure(conf)

        evaluator_makers = OrderedDict([
            ('b', BinaryEvaluator),
            ('l', LossEvaluator),
            ('f', BinForetellerEvaluator.standard_creator('foretell')),
            ('bf', BinFaithfulnessEvaluator.standard_creator('foretell')),
        ])
        conf.evaluator_cls = (
            MultiEvaluatorEvaluator
            .create_standard_multi_evaluator_evaluator_maker(evaluator_makers)
        )
        # The 'b_' prefix appears to refer to the 'b' (BinaryEvaluator) entry
        # above — NOTE(review): confirm the metric naming scheme.
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---------------------------- Foreteller ----------------------------
        foretell_names = (
            'foretell_pool2',
            'foretell_pool6',
            'foretell_pool7',
            'foretell_avgpool',
        )
        aux = AuxOutput(model, dict(zip(foretell_names, attention_layers())))
        aux.configure(conf)

        def max_over_flattened(x):
            # Flatten every dim after the first, then keep the per-row maximum.
            # Assumes x is a torch tensor (max(dim=...) returns .values).
            return x.flatten(1).max(dim=-1).values

        output_modifier = OutputModifier(model, max_over_flattened, *foretell_names)
        output_modifier.configure(conf)

        return conf, model