from collections import OrderedDict
from typing import Tuple

from ...utils.output_modifier import OutputModifier
from ...utils.aux_output import AuxOutput
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
from ...configs.rsna_configs import RSNAConfigs
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.rsna.lap_resnet import RSNALAPResNet18


class EntryPoint(BaseEntrypoint):
    """Entrypoint wiring an ``RSNALAPResNet18`` model to its config.

    Builds the ``RSNAConfigs`` object and attaches to it: one discriminative
    weakly supervised loss per pooling head, two pool concordance losses (one
    over the attention layers, one over the discrimination layers), a
    multi-evaluator, and the auxiliary "foretell" outputs.
    """

    def __init__(self, phase_type) -> None:
        # The ratios are assigned before delegating to the base initialiser,
        # exactly as in the original ordering.
        # NOTE(review): presumably the base __init__ may already call
        # _get_conf_model, which reads these attributes -- confirm before
        # reordering.
        self.active_min_ratio: float = 0.1
        self.active_max_ratio: float = 0.5
        self.inactive_ratio: float = 0.1
        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[RSNAConfigs, Model]:
        """Create the config/model pair and register all losses and outputs.

        Returns:
            The configured ``RSNAConfigs`` instance and the
            ``RSNALAPResNet18`` model it was configured for.
        """
        # NOTE(review): the meaning of the positional arguments 1 and 224 is
        # defined by RSNAConfigs, not visible here -- confirm there.
        config = RSNAConfigs('RSNA_ResNet_WS', 1, 224, self.phase_type)
        model = RSNALAPResNet18()

        # The four pooling modules that expose both an ``attention_layer``
        # and a ``discrimination_layer``, keyed by the title used for them.
        pool2 = model.layer2[0].pool
        pool3 = model.layer3[0].pool
        pool4 = model.layer4[0].pool
        pool5 = model.avgpool[0]
        titled_pools = (
            ('att2', pool2),
            ('att3', pool3),
            ('att4', pool4),
            ('att5', pool5),
        )

        # Weakly supervised losses based on the discriminative head and the
        # attention head. Every loss is constructed first; only afterwards
        # is each one configured, in the same order.
        head_specs = [
            (title, pool.attention_layer, pool.discrimination_layer)
            for title, pool in titled_pools
        ]
        ws_losses = []
        for title, att_score_layer, discr_score_layer in head_specs:
            active_ratio_range = (
                self.active_min_ratio,
                self.active_max_ratio,
            )
            ws_losses.append(
                DiscriminativeWeaklySupervisedLoss(
                    title,
                    model,
                    att_score_layer,
                    0.25,
                    self.inactive_ratio,
                    active_ratio_range,
                    {0: [1]},
                    discr_score_layer,
                    w_attention_in_ordering=0.2,
                    w_discr_in_ordering=1,
                )
            )
        for loss in ws_losses:
            loss.configure(config)

        # Concordance loss for the attention head.
        attention_by_title = OrderedDict(
            (title, pool.attention_layer) for title, pool in titled_pools
        )
        concordance_loss = PoolConcordanceLossCalculator(
            'AC',
            model,
            attention_by_title,
            loss_weight=1,
            weights=1 / 4,
            diff_thresholds=0.1,
            labels_by_channel={0: [1]},
        )
        concordance_loss.configure(config)

        # Concordance loss for the discrimination head.
        discrimination_by_title = OrderedDict(
            (title, pool.discrimination_layer) for title, pool in titled_pools
        )
        concordance_loss2 = PoolConcordanceLossCalculator(
            'DC',
            model,
            discrimination_by_title,
            loss_weight=1,
            weights=0.5 / 4,
            diff_thresholds=0,
            labels_by_channel={0: [1]},
        )
        concordance_loss2.configure(config)

        # Evaluators, keyed by the short prefix each one is registered under.
        evaluator_makers = OrderedDict([
            ('b', BinaryEvaluator),
            ('l', LossEvaluator),
            ('f', BinForetellerEvaluator.standard_creator('foretell')),
            ('bf', BinFaithfulnessEvaluator.standard_creator('foretell')),
        ])
        config.evaluator_cls = (
            MultiEvaluatorEvaluator
            .create_standard_multi_evaluator_evaluator_maker(evaluator_makers)
        )
        config.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        ###################################
        ########### Foreteller ############
        ###################################

        # Register the attention layers as auxiliary outputs under the
        # ``foretell_*`` names.
        foretell_layers = {
            'foretell_pool2': pool2.attention_layer,
            'foretell_pool3': pool3.attention_layer,
            'foretell_pool4': pool4.attention_layer,
            'foretell_avgpool': pool5.attention_layer,
        }
        aux = AuxOutput(model, foretell_layers)
        aux.configure(config)

        # Each named output is flattened from dim 1 onwards and reduced to
        # its maximum along the last dimension.
        output_modifier = OutputModifier(
            model,
            lambda x: x.flatten(1).max(dim=-1).values,
            'foretell_pool2',
            'foretell_pool3',
            'foretell_pool4',
            'foretell_avgpool',
        )
        output_modifier.configure(config)

        return config, model