- from typing import Tuple
- from collections import OrderedDict
- from ...configs.celeba_configs import CelebAConfigs, CelebATag
- from ...models.model import Model
- from ..entrypoint import BaseEntrypoint
- from ...models.celeba.lap_inception import CelebALAPInception
- from ...modules.lap import LAP
- from ...modules.adaptive_lap import AdaptiveLAP
- from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
- from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
- from ...utils.aux_output import AuxOutput
- from ...utils.output_modifier import OutputModifier
- from ...model_evaluation.binary_evaluator import BinaryEvaluator
- from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
- from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
- from ...model_evaluation.loss_evaluator import LossEvaluator
- from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator
-
-
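- # Factory for the LAP pooling blocks plugged into the Inception backbone:
- # 3 attention heads, one hidden layer of 8 channels, and discriminative
- # attention enabled. The two positional 2s are assumed to be the pooling
- # kernel size and stride.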
- def lap_factory(channel):
- return LAP(
- channel, 2, 2, hidden_channels=[8], n_attention_heads=3,
- sigmoid_scale=0.1, discriminative_attention=True)
-
-
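- # Factory for the adaptive (global) pooling stage; the empty list and 0.1 are
- # assumed to mirror lap_factory's hidden_channels and sigmoid_scale.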
- def adaptive_lap_factory(channel):
- return AdaptiveLAP(channel, [], 0.1,
- n_attention_heads=3, discriminative_attention=True)
-
-
- class EntryPoint(BaseEntrypoint):
-
- def __init__(self, phase_type) -> None:
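- # Area-ratio bounds consumed by the weakly supervised losses below: limits on
- # the active attention area per class, the tolerated inactive area, and the
- # cap shared by the free head.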
- self.active_min_ratio: float = 0.02
- self.active_max_ratio: float = 0.4
- self.inactive_ratio: float = 0.01
- self.common_max_ratio = 0.02
-
- super().__init__(phase_type)
-
- def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
- conf = CelebAConfigs('CelebA_Inception_WS', 10, 256, self.phase_type)
-
- conf.max_epochs = 12
- conf.main_tag = CelebATag.SmilingTag
- conf.tags = [conf.main_tag.name]
-
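- # LAP-Inception classifier for the main tag, built from the two pooling
- # factories defined above.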
- model = CelebALAPInception(conf.main_tag.name, 0.4, lap_factory, adaptive_lap_factory)
-
- # Auxiliary weakly supervised losses for the free head: channel 2 covers both
- # labels, its active area is capped at common_max_ratio, and the ordering term
- # uses the attention scores only.
-
- fws_losses = [
- DiscriminativeWeaklySupervisedLoss(
- title, model, att_score_layer, 0.025 / 3, 0,
- (0, self.common_max_ratio),
- {2: [0, 1]}, discr_score_layer,
- w_attention_in_ordering=1, w_discr_in_ordering=0)
- for title, att_score_layer, discr_score_layer in [
- ('FPool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer),
- ('FPool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer),
- ('FPool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer),
- ('FPoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
- ]
- ]
- for fws_loss in fws_losses:
- fws_loss.configure(conf)
-
- # Weakly supervised losses on the label-specific channels (channel 0 for
- # label 0, channel 1 for label 1), bounded by the active/inactive area ratios
- # and ordered mainly by the discrimination scores.
-
- ws_losses = [
- DiscriminativeWeaklySupervisedLoss(
- title, model, att_score_layer, 0.025 * 2 / 3,
- self.inactive_ratio,
- (self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]},
- discr_score_layer,
- w_attention_in_ordering=0.2, w_discr_in_ordering=1)
- for title, att_score_layer, discr_score_layer in [
- ('Pool2', model.maxpool2.attention_layer, model.maxpool2.discrimination_layer),
- ('Pool6', model.Mixed_6a.pool.attention_layer, model.Mixed_6a.pool.discrimination_layer),
- ('Pool7', model.Mixed_7a.pool.attention_layer, model.Mixed_7a.pool.discrimination_layer),
- ('PoolG', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
- ]
- ]
- for ws_loss in ws_losses:
- ws_loss.configure(conf)
-
- # Concordance loss encouraging agreement between the attention maps of the
- # four pooling stages.
-
- concordance_loss = PoolConcordanceLossCalculator(
- 'AC', model, OrderedDict([
- ('att2', model.maxpool2.attention_layer),
- ('att6', model.Mixed_6a.pool.attention_layer),
- ('att7', model.Mixed_7a.pool.attention_layer),
- ('attG', model.avgpool[0].attention_layer),
- ]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0,
- labels_by_channel={0: [0], 1: [1]})
- concordance_loss.configure(conf)
-
- # The same concordance loss applied to the discrimination heads, with half
- # the weight.
-
- concordance_loss2 = PoolConcordanceLossCalculator(
- 'DC', model, OrderedDict([
- ('D-att2', model.maxpool2.discrimination_layer),
- ('D-att6', model.Mixed_6a.pool.discrimination_layer),
- ('D-att7', model.Mixed_7a.pool.discrimination_layer),
- ('D-attG', model.avgpool[0].discrimination_layer),
- ]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0,
- labels_by_channel={0: [0], 1: [1]})
- concordance_loss2.configure(conf)
-
-
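- # Composite evaluator: 'b' for binary classification metrics, 'l' for the
- # loss terms, and 'f'/'bf' for the foreteller and faithfulness evaluators,
- # which read the auxiliary 'foretell_*' outputs registered below.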
- conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(OrderedDict({
- 'b': BinaryEvaluator,
- 'l': LossEvaluator,
- 'f': BinForetellerEvaluator.standard_creator('foretell'),
- 'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
- }))
- # Select the best epoch by the 'b_BAcc' metric from the binary evaluator
- # (presumably balanced accuracy).
- conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'
-
- ###################################
- ########### Foreteller ############
- ###################################
-
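- # Expose each attention layer's output as an auxiliary 'foretell_*' head so
- # the foreteller and faithfulness evaluators can consume it.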
- aux = AuxOutput(model, dict(
- foretell_pool2=model.maxpool2.attention_layer,
- foretell_pool6=model.Mixed_6a.pool.attention_layer,
- foretell_pool7=model.Mixed_7a.pool.attention_layer,
- foretell_avgpool=model.avgpool[0].attention_layer,
- ))
- aux.configure(conf)
-
-
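- # Turn each auxiliary attention map into a hard prediction: threshold at 0.5,
- # sum the positive mass of each channel over spatial positions, and take the
- # argmax over the first two (label-specific) channels.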
- output_modifier = OutputModifier(model,
- lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1),
- 'foretell_pool2',
- 'foretell_pool6',
- 'foretell_pool7',
- 'foretell_avgpool',
- )
- output_modifier.configure(conf)
-
- return conf, model