
# ws_lap_inception.py

from typing import Tuple
from collections import OrderedDict

from ...configs.celeba_configs import CelebAConfigs, CelebATag
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.celeba.lap_inception import CelebALAPInception
from ...modules.lap import LAP
from ...modules.adaptive_lap import AdaptiveLAP
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
from ...utils.aux_output import AuxOutput
from ...utils.output_modifier import OutputModifier
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator


def lap_factory(channel):
    return LAP(
        channel, 2, 2, hidden_channels=[8], n_attention_heads=3,
        sigmoid_scale=0.1, discriminative_attention=True)


def adaptive_lap_factory(channel):
    return AdaptiveLAP(channel, [], 0.1,
                       n_attention_heads=3, discriminative_attention=True)
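

# A hedged usage sketch: the backbone below receives these two factories and
# builds one pooling layer per stage by calling them with that stage's
# channel count, roughly:
#
#   pool = lap_factory(64)    # hypothetical 64-channel feature map
#
# The exact forward signature of the resulting layer is defined in
# ...modules.lap and is not assumed here.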


class EntryPoint(BaseEntrypoint):

    def __init__(self, phase_type) -> None:
        # ratio bounds consumed by the weakly supervised losses built below
        self.active_min_ratio: float = 0.02
        self.active_max_ratio: float = 0.4
        self.inactive_ratio: float = 0.01
        self.common_max_ratio = 0.02

        super().__init__(phase_type)

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
        conf = CelebAConfigs('CelebA_Inception_WS', 10, 256, self.phase_type)
        conf.max_epochs = 12
        conf.main_tag = CelebATag.SmilingTag
        conf.tags = [conf.main_tag.name]

        model = CelebALAPInception(
            conf.main_tag.name, 0.4, lap_factory, adaptive_lap_factory)

        # auxiliary weakly supervised losses for the free head
        fws_losses = [
            DiscriminativeWeaklySupervisedLoss(
                title, model, att_score_layer, 0.025 / 3, 0,
                (0, self.common_max_ratio),
                {2: [0, 1]}, discr_score_layer,
                w_attention_in_ordering=1, w_discr_in_ordering=0)
            for title, att_score_layer, discr_score_layer in [
                ('FPool2', model.maxpool2.attention_layer,
                 model.maxpool2.discrimination_layer),
                ('FPool6', model.Mixed_6a.pool.attention_layer,
                 model.Mixed_6a.pool.discrimination_layer),
                ('FPool7', model.Mixed_7a.pool.attention_layer,
                 model.Mixed_7a.pool.discrimination_layer),
                ('FPoolG', model.avgpool[0].attention_layer,
                 model.avgpool[0].discrimination_layer),
            ]
        ]
        for fws_loss in fws_losses:
            fws_loss.configure(conf)

        # weakly supervised losses based on discriminative head and attention head
        ws_losses = [
            DiscriminativeWeaklySupervisedLoss(
                title, model, att_score_layer, 0.025 * 2 / 3,
                self.inactive_ratio,
                (self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]},
                discr_score_layer,
                w_attention_in_ordering=0.2, w_discr_in_ordering=1)
            for title, att_score_layer, discr_score_layer in [
                ('Pool2', model.maxpool2.attention_layer,
                 model.maxpool2.discrimination_layer),
                ('Pool6', model.Mixed_6a.pool.attention_layer,
                 model.Mixed_6a.pool.discrimination_layer),
                ('Pool7', model.Mixed_7a.pool.attention_layer,
                 model.Mixed_7a.pool.discrimination_layer),
                ('PoolG', model.avgpool[0].attention_layer,
                 model.avgpool[0].discrimination_layer),
            ]
        ]
        for ws_loss in ws_losses:
            ws_loss.configure(conf)
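
        # Note (a hedged reading of the two loss groups above): the F* losses
        # let a shared "free" channel (index 2) stay active under both labels
        # via labels_by_channel={2: [0, 1]} and order candidates purely by
        # attention (w_attention_in_ordering=1), while these per-class losses
        # tie channels 0 and 1 to labels 0 and 1, bound their active area to
        # (active_min_ratio, active_max_ratio), and order candidates mainly by
        # the discrimination head (w_discr_in_ordering=1).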

        # concordance loss for attention
        concordance_loss = PoolConcordanceLossCalculator(
            'AC', model, OrderedDict([
                ('att2', model.maxpool2.attention_layer),
                ('att6', model.Mixed_6a.pool.attention_layer),
                ('att7', model.Mixed_7a.pool.attention_layer),
                ('attG', model.avgpool[0].attention_layer),
            ]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        concordance_loss.configure(conf)

        # concordance loss for discrimination head
        concordance_loss2 = PoolConcordanceLossCalculator(
            'DC', model, OrderedDict([
                ('D-att2', model.maxpool2.discrimination_layer),
                ('D-att6', model.Mixed_6a.pool.discrimination_layer),
                ('D-att7', model.Mixed_7a.pool.discrimination_layer),
                ('D-attG', model.avgpool[0].discrimination_layer),
            ]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        concordance_loss2.configure(conf)

        # evaluators: binary metrics, loss, foretelling, and faithfulness;
        # the dict keys become metric-name prefixes, e.g. 'b_BAcc' below
        conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(
            OrderedDict({
                'b': BinaryEvaluator,
                'l': LossEvaluator,
                'f': BinForetellerEvaluator.standard_creator('foretell'),
                'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
            }))
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'

        ###################################
        ########### Foreteller ############
        ###################################
        aux = AuxOutput(model, dict(
            foretell_pool2=model.maxpool2.attention_layer,
            foretell_pool6=model.Mixed_6a.pool.attention_layer,
            foretell_pool7=model.Mixed_7a.pool.attention_layer,
            foretell_avgpool=model.avgpool[0].attention_layer,
        ))
        aux.configure(conf)

        # threshold each attention map at 0.5, sum the above-threshold mass
        # per channel, and take the argmax over the first two channels as the
        # foretold label
        output_modifier = OutputModifier(
            model,
            lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1),
            'foretell_pool2',
            'foretell_pool6',
            'foretell_pool7',
            'foretell_avgpool',
        )
        output_modifier.configure(conf)

        return conf, model
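

if __name__ == '__main__':
    # Smoke-test sketch (an assumption, not part of the original file):
    # 'train' is a hypothetical phase_type value; the real construction
    # arguments for BaseEntrypoint are defined in ..entrypoint.
    entry = EntryPoint('train')
    conf, model = entry._get_conf_model()
    print(conf.main_tag, type(model).__name__)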