from typing import Tuple
from collections import OrderedDict
from ...configs.celeba_configs import CelebAConfigs, CelebATag
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.celeba.lap_inception import CelebALAPInception
from ...modules.lap import LAP
from ...modules.adaptive_lap import AdaptiveLAP
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
from ...utils.aux_output import AuxOutput
from ...utils.output_modifier import OutputModifier
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator


def lap_factory(channel):
    """Build a ``LAP`` pooling module with 3 discriminative attention heads.

    Args:
        channel: forwarded as the first positional argument of ``LAP``
            (presumably the number of input channels — confirm).
    """
    return LAP(
        channel, 2, 2,
        hidden_channels=[8],
        n_attention_heads=3,
        sigmoid_scale=0.1,
        discriminative_attention=True,
    )


def adaptive_lap_factory(channel):
    """Build an ``AdaptiveLAP`` module with 3 discriminative attention heads.

    Args:
        channel: forwarded as the first positional argument of
            ``AdaptiveLAP`` (presumably the number of input channels — confirm).
    """
    # A fresh empty list is passed on every call, so no list is shared
    # between the modules this factory creates.
    return AdaptiveLAP(
        channel, [], 0.1,
        n_attention_heads=3,
        discriminative_attention=True,
    )


class EntryPoint(BaseEntrypoint):
    """CelebA "Smiling" entry point: LAP-pooled Inception with weakly
    supervised, concordance and foreteller auxiliary objectives."""

    def __init__(self, phase_type) -> None:
        # These are assigned before delegating to the base initializer.
        # NOTE(review): presumably BaseEntrypoint.__init__ ends up calling
        # _get_conf_model (which reads them) — confirm before reordering.
        self.active_min_ratio: float = 0.02
        self.active_max_ratio: float = 0.4
        self.inactive_ratio: float = 0.01
        self.common_max_ratio: float = 0.02
        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
        """Create the experiment config and model, then attach every
        auxiliary loss, the evaluators and the foreteller output modifier
        to the config.

        Returns:
            The populated ``CelebAConfigs`` and the ``CelebALAPInception``.
        """
        conf = CelebAConfigs('CelebA_Inception_WS', 10, 256, self.phase_type)
        conf.max_epochs = 12
        conf.main_tag = CelebATag.SmilingTag
        conf.tags = [conf.main_tag.name]

        model = CelebALAPInception(
            conf.main_tag.name, 0.4, lap_factory, adaptive_lap_factory)

        # The four LAP pooling points every objective below is attached to,
        # as (attention layer, discrimination layer) pairs, in a fixed order.
        pool_modules = (
            model.maxpool2,
            model.Mixed_6a.pool,
            model.Mixed_7a.pool,
            model.avgpool[0],
        )
        pool_layers = [
            (pool.attention_layer, pool.discrimination_layer)
            for pool in pool_modules
        ]
        attention_layers = [att for att, _ in pool_layers]
        discrimination_layers = [discr for _, discr in pool_layers]

        # Auxiliary loss for the free head. The mapping {2: [0, 1]} ties
        # channel 2 to both labels (presumably a labels-by-channel map — confirm).
        free_head_titles = ('FPool2', 'FPool6', 'FPool7', 'FPoolG')
        fws_losses = []
        for title, (att_layer, discr_layer) in zip(free_head_titles, pool_layers):
            fws_losses.append(DiscriminativeWeaklySupervisedLoss(
                title, model, att_layer, 0.025 / 3, 0,
                (0, self.common_max_ratio), {2: [0, 1]}, discr_layer,
                w_attention_in_ordering=1, w_discr_in_ordering=0))
        # All losses of a family are constructed before any is configured,
        # matching the original construction/configure order.
        for loss in fws_losses:
            loss.configure(conf)

        # Weakly supervised losses based on the discriminative and
        # attention heads, one channel per label.
        ws_titles = ('Pool2', 'Pool6', 'Pool7', 'PoolG')
        ws_losses = []
        for title, (att_layer, discr_layer) in zip(ws_titles, pool_layers):
            ws_losses.append(DiscriminativeWeaklySupervisedLoss(
                title, model, att_layer, 0.025 * 2 / 3, self.inactive_ratio,
                (self.active_min_ratio, self.active_max_ratio),
                {0: [0], 1: [1]}, discr_layer,
                w_attention_in_ordering=0.2, w_discr_in_ordering=1))
        for loss in ws_losses:
            loss.configure(conf)

        # Concordance loss across the attention heads of all pooling points.
        attention_concordance = PoolConcordanceLossCalculator(
            'AC', model,
            OrderedDict(zip(('att2', 'att6', 'att7', 'attG'), attention_layers)),
            loss_weight=1, weights=0.1 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        attention_concordance.configure(conf)

        # Concordance loss across the discrimination heads.
        discrimination_concordance = PoolConcordanceLossCalculator(
            'DC', model,
            OrderedDict(zip(('D-att2', 'D-att6', 'D-att7', 'D-attG'),
                            discrimination_layers)),
            loss_weight=1, weights=0.05 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        discrimination_concordance.configure(conf)

        evaluator_makers = OrderedDict({
            'b': BinaryEvaluator,
            'l': LossEvaluator,
            'f': BinForetellerEvaluator.standard_creator('foretell'),
            'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
        })
        conf.evaluator_cls = (
            MultiEvaluatorEvaluator
            .create_standard_multi_evaluator_evaluator_maker(evaluator_makers))
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        # ---------------------------- Foreteller ----------------------------
        foretell_names = (
            'foretell_pool2', 'foretell_pool6', 'foretell_pool7', 'foretell_avgpool')

        aux = AuxOutput(model, dict(zip(foretell_names, attention_layers)))
        aux.configure(conf)

        def foretell_to_label(x):
            # Keep only the part of each map above 0.5, sum it over all
            # dimensions from 2 onward, and pick the larger of the first two
            # channels. NOTE(review): assumes x is (batch, channels, ...) — confirm.
            excess = (x - 0.5).relu()
            per_channel = excess.flatten(2).sum(dim=2)
            return per_channel[:, :2].argmax(dim=1)

        output_modifier = OutputModifier(model, foretell_to_label, *foretell_names)
        output_modifier.configure(conf)

        return conf, model