from collections import OrderedDict
from typing import Tuple

from ...utils.output_modifier import OutputModifier
from ...utils.aux_output import AuxOutput
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...criteria.bb_supervised import BBLoss
from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.lap_resnet import (
    RSNALAPResNet18,
    get_adaptive_pool_factory,
    get_pool_factory,
)


class EntryPoint(BaseEntrypoint):
    """Entrypoint wiring an ``RSNALAPResNet18`` model into an ``RSNAConfigs``.

    The model's four attention layers are handed both to a ``BBLoss`` criterion
    and to an ``AuxOutput`` that exposes them under ``foretell_*`` names.
    """

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Build the experiment configuration together with its model.

        Returns:
            ``(config, model)``. ``config`` has been passed to the
            ``configure`` method of a ``BBLoss``, an ``AuxOutput`` and an
            ``OutputModifier``, in that order. It also carries the evaluator
            factory and the reference metric used to pick the best epoch.
        """
        # NOTE(review): meaning of the positional args 1 and 224 is defined by
        # RSNAConfigs, which is not visible here.
        config = RSNAConfigs('RSNA_ResNet_BB', 1, 224, self.phase_type)
        model = RSNALAPResNet18(
            pool_factory=get_pool_factory(discriminative_attention=False),
            adaptive_factory=get_adaptive_pool_factory(discriminative_attention=False),
            sigmoid_scale=0.1,
        )

        def attention_layers():
            # Re-read the attributes at every call rather than caching them, so
            # each consumer gets whatever the model holds at that moment
            # (a configure() call in between could conceivably swap them).
            return (
                model.layer2[0].pool.attention_layer,
                model.layer3[0].pool.attention_layer,
                model.layer4[0].pool.attention_layer,
                model.avgpool[0].attention_layer,
            )

        # BBLoss over the four attention layers; the trailing positional
        # arguments are passed through unchanged to BBLoss.
        att2, att3, att4, att5 = attention_layers()
        bb_losses = BBLoss(
            'BB',
            model,
            {'att2': att2, 'att3': att3, 'att4': att4, 'att5': att5},
            1, 'interpretations', 1, 0.5,
        )
        bb_losses.configure(config)

        # Evaluators, inserted in the order the multi-evaluator receives them.
        evaluators = OrderedDict()
        evaluators['b'] = BinaryEvaluator
        evaluators['l'] = LossEvaluator
        evaluators['f'] = BinForetellerEvaluator.standard_creator('foretell')
        evaluators['bf'] = BinFaithfulnessEvaluator.standard_creator('foretell')
        config.evaluator_cls = (
            MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(evaluators)
        )
        config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---------------------------- Foreteller ----------------------------
        foretell_names = (
            'foretell_pool2',
            'foretell_pool3',
            'foretell_pool4',
            'foretell_avgpool',
        )
        aux = AuxOutput(model, dict(zip(foretell_names, attention_layers())))
        aux.configure(config)

        def max_over_flattened(x):
            # Flatten every dimension after the first and keep the maximum,
            # giving one value per leading index (presumably the batch
            # dimension — assumes x is a torch tensor).
            return x.flatten(1).max(dim=-1).values

        output_modifier = OutputModifier(model, max_over_flattened, *foretell_names)
        output_modifier.configure(config)

        return config, model