# ws_lap_resnet.py
from typing import Tuple
from collections import OrderedDict

from ...configs.celeba_configs import CelebAConfigs, CelebATag
from ...models.model import Model
from ..entrypoint import BaseEntrypoint
from ...models.celeba.lap_resnet import CelebALAPResNet18
from ...criteria.cw_concordance_loss import PoolConcordanceLossCalculator
from ...criteria.weakly_supervised import DiscriminativeWeaklySupervisedLoss
from ...utils.aux_output import AuxOutput
from ...utils.output_modifier import OutputModifier
from ...model_evaluation.binary_evaluator import BinaryEvaluator
from ...model_evaluation.binary_fortelling import BinForetellerEvaluator
from ...model_evaluation.binary_faithfulness import BinFaithfulnessEvaluator
from ...model_evaluation.loss_evaluator import LossEvaluator
from ...model_evaluation.multieval_evaluator import MultiEvaluatorEvaluator

class EntryPoint(BaseEntrypoint):

    def __init__(self, phase_type) -> None:
        self.active_min_ratio: float = 0.02
        self.active_max_ratio: float = 0.4
        self.inactive_ratio: float = 0.01
        self.common_max_ratio = 0.02
        super().__init__(phase_type)
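
    # The ratio fields above act, by their names, as bounds on how much of each
    # attention map may be active or inactive; they parameterize the weakly
    # supervised losses built in _get_conf_model below.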

    def _get_conf_model(self) -> Tuple[CelebAConfigs, Model]:
        conf = CelebAConfigs('CelebA_ResNet_WS', 11, 224, self.phase_type)
        conf.max_epochs = 12
        conf.main_tag = CelebATag.SmilingTag
        conf.tags = [conf.main_tag.name]

        model = CelebALAPResNet18(conf.main_tag.name, sigmoid_scale=0.1)
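        # The LAP model exposes an attention layer and a discrimination layer at
        # four pooling sites (layer2[0].pool, layer3[0].pool, layer4[0].pool, and
        # avgpool[0]); all the losses below hook into these layers.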

        # aux loss for free head
        fws_losses = [
            DiscriminativeWeaklySupervisedLoss(
                title, model, att_score_layer, 0.025 / 3, 0,
                (0, self.common_max_ratio),
                {2: [0, 1]}, discr_score_layer,
                w_attention_in_ordering=1, w_discr_in_ordering=0)
            for title, att_score_layer, discr_score_layer in [
                ('Fatt2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer),
                ('Fatt3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer),
                ('Fatt4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer),
                ('Fatt5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
            ]
        ]
        for fws_loss in fws_losses:
            fws_loss.configure(conf)
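        # In the channel-to-labels mapping {2: [0, 1]}, channel 2 is tied to both
        # labels, making this a label-agnostic "free" head whose active area is
        # capped at common_max_ratio.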

        # weakly supervised losses based on discriminative head and attention head
        ws_losses = [
            DiscriminativeWeaklySupervisedLoss(
                title, model, att_score_layer, 0.025 * 2 / 3,
                self.inactive_ratio,
                (self.active_min_ratio, self.active_max_ratio), {0: [0], 1: [1]},
                discr_score_layer,
                w_attention_in_ordering=0.2, w_discr_in_ordering=1)
            for title, att_score_layer, discr_score_layer in [
                ('att2', model.layer2[0].pool.attention_layer, model.layer2[0].pool.discrimination_layer),
                ('att3', model.layer3[0].pool.attention_layer, model.layer3[0].pool.discrimination_layer),
                ('att4', model.layer4[0].pool.attention_layer, model.layer4[0].pool.discrimination_layer),
                ('att5', model.avgpool[0].attention_layer, model.avgpool[0].discrimination_layer),
            ]
        ]
        for ws_loss in ws_losses:
            ws_loss.configure(conf)
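        # Per pooling site, the free-head and per-class losses together carry a
        # total attention-loss weight of 0.025 (0.025 / 3 + 0.025 * 2 / 3).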

        # concordance loss for attention
        concordance_loss = PoolConcordanceLossCalculator(
            'AC', model, OrderedDict([
                ('att2', model.layer2[0].pool.attention_layer),
                ('att3', model.layer3[0].pool.attention_layer),
                ('att4', model.layer4[0].pool.attention_layer),
                ('att5', model.avgpool[0].attention_layer),
            ]), loss_weight=1, weights=0.1 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        concordance_loss.configure(conf)

        # concordance loss for discrimination head
        concordance_loss2 = PoolConcordanceLossCalculator(
            'DC', model, OrderedDict([
                ('att2', model.layer2[0].pool.discrimination_layer),
                ('att3', model.layer3[0].pool.discrimination_layer),
                ('att4', model.layer4[0].pool.discrimination_layer),
                ('att5', model.avgpool[0].discrimination_layer),
            ]), loss_weight=1, weights=0.05 / 4, diff_thresholds=0,
            labels_by_channel={0: [0], 1: [1]})
        concordance_loss2.configure(conf)
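        # Both concordance weights are split over the four hooked sites (hence
        # the division by 4); the discrimination-head term (0.05) carries half
        # the weight of the attention-head term (0.1).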

        conf.evaluator_cls = MultiEvaluatorEvaluator.create_standard_multi_evaluator_evaluator_maker(
            OrderedDict({
                'b': BinaryEvaluator,
                'l': LossEvaluator,
                'f': BinForetellerEvaluator.standard_creator('foretell'),
                'bf': BinFaithfulnessEvaluator.standard_creator('foretell'),
            }))
        conf.title_of_reference_metric_to_choose_best_epoch = 'b_BAcc'
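        # 'b_BAcc' is, by construction of the evaluator keys above, the 'b'
        # (BinaryEvaluator) prefix joined with a metric title, presumably
        # balanced accuracy.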

        ###################################
        ########### Foreteller ############
        ###################################

        aux = AuxOutput(model, dict(
            foretell_pool2=model.layer2[0].pool.attention_layer,
            foretell_pool3=model.layer3[0].pool.attention_layer,
            foretell_pool4=model.layer4[0].pool.attention_layer,
            foretell_avgpool=model.avgpool[0].attention_layer,
        ))
        aux.configure(conf)
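        # The output modifier below reduces each hooked attention map to a hard
        # binary prediction: threshold at 0.5, sum the surviving activation over
        # spatial positions per channel, and take the argmax of the first two
        # channels.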
        output_modifier = OutputModifier(
            model,
            lambda x: (x - 0.5).relu().flatten(2).sum(dim=2)[:, :2].argmax(dim=1),
            'foretell_pool2',
            'foretell_pool3',
            'foretell_pool4',
            'foretell_avgpool',
        )
        output_modifier.configure(conf)

        return conf, model
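
# A minimal usage sketch (hypothetical call sequence; the real driver and the
# valid phase_type values live in BaseEntrypoint elsewhere in this repo):
#
#     entry = EntryPoint(phase_type)
#     conf, model = entry._get_conf_model()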