
# model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
import warnings
warnings.filterwarnings('ignore')
from mixstyle import MixStyle  # local module (a reference sketch is included at the end of this file)

# Feature extractor combining DINO and CLIP backbones, with an optional
# cosine-similarity branch and optional MixStyle augmentation.
class HybridModel(nn.Module):
    def __init__(self, dino_model, clip_model, feature_dim, use_featup=False,
                 use_mixstyle=False, cosine_branch=False, num_layers=4, output_dim=1):
        super(HybridModel, self).__init__()
        self.dino_model = dino_model
        self.clip_model = clip_model
        self.use_featup = use_featup
        self.use_mixstyle = use_mixstyle
        self.cosine_branch = cosine_branch
        # Hard-coded to the 512-dim CLIP embedding; this overrides the
        # feature_dim constructor argument.
        feature_dim = 512
        # Load the FeatUp upsampler for CLIP (the DINOv2 upsampler is disabled)
        if use_featup:
            # self.dino_upsampler = torch.hub.load("mhamilton723/FeatUp", "dinov2", use_norm=True)
            self.clip_upsampler = torch.hub.load("mhamilton723/FeatUp", "clip", use_norm=True)
            # Freeze the upsampler and keep it in eval mode
            self.clip_upsampler.eval()
            for param in self.clip_upsampler.parameters():
                param.requires_grad = False
        # MixStyle: mixes feature statistics across instances with probability p,
        # with mixing weights drawn from Beta(alpha, alpha)
        if use_mixstyle:
            self.mixstyle = MixStyle(p=0.5, alpha=0.1, mix='random')
        # Build the classifier MLP based on the requested depth
        classifier_layers = []
        if num_layers == 1:
            classifier_layers = [
                nn.Linear(feature_dim, output_dim)
            ]
        elif num_layers == 2:
            classifier_layers = [
                nn.Linear(feature_dim, 256),
                nn.ReLU(),
                nn.Linear(256, output_dim)
            ]
        elif num_layers == 3:
            classifier_layers = [
                nn.Linear(feature_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, output_dim)
            ]
        elif num_layers == 4:
            classifier_layers = [
                nn.Linear(feature_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, output_dim)
            ]
        elif num_layers == 5:
            classifier_layers = [
                nn.Linear(feature_dim, 1024),
                nn.ReLU(),
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, output_dim)
            ]
        else:
            raise ValueError("num_layers must be between 1 and 5")
        self.classifier = nn.Sequential(*classifier_layers)
        if self.cosine_branch:
            self.similarity_branch = nn.Sequential(
                nn.Linear(1, 16),
                nn.ReLU(),
                nn.Linear(16, output_dim)
            )
            # Learnable weighted sum of the similarity and classifier outputs
            self.weighted_sum_layer = nn.Linear(2, 1)
    def forward_old(self, x):
        # Legacy forward pass: concatenates DINO and CLIP features.
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized_featup = norm(x)
        # Extract DINO features. The FeatUp DINOv2 upsampler is disabled in
        # __init__, so the backbone features are used directly.
        dino_original_features = self.dino_model(x).squeeze()
        # Apply MixStyle to the DINO features
        if self.use_mixstyle:
            dino_original_features = self.mixstyle(dino_original_features)
        # Extract CLIP features, upsampled with FeatUp if enabled
        if self.use_featup:
            print("======= Using FeatUp for CLIP =======")
            clip_features = self.clip_upsampler(x_normalized_featup)
        else:
            clip_features = self.clip_model.encode_image(x)
        # Apply MixStyle to the CLIP features
        if self.use_mixstyle:
            clip_features = self.mixstyle(clip_features)
        # Combine features and classify
        combined_features = torch.cat((dino_original_features, clip_features), dim=-1)
        main_output = self.classifier(combined_features)
        return main_output
        # Cosine-similarity branch (disabled):
        # cosine_sim = F.cosine_similarity(dino_original_features, dino_noisy_features, dim=-1).unsqueeze(-1)
        # if self.cosine_branch:
        #     similarity_output = self.similarity_branch(cosine_sim)
        #     final_output = self.weighted_sum_layer(torch.cat((main_output, similarity_output), dim=-1))
        #     return final_output
        # else:
        #     return main_output
    def forward(self, x):
        # Normalize input for FeatUp compatibility
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)
        # Extract CLIP features, upsampled with FeatUp if enabled
        if self.use_featup:
            try:
                clip_features = self.clip_upsampler(x_normalized)
                # Global average pooling to convert spatial features to a vector
                clip_features = F.adaptive_avg_pool2d(clip_features, (1, 1)).squeeze(-1).squeeze(-1)
            except Exception as e:
                print(f"CLIP upsampler failed: {e}, using standard features")
                clip_features = self.clip_model.encode_image(x_normalized)
        else:
            clip_features = self.clip_model.encode_image(x_normalized)
        # Use only CLIP features (no DINO concatenation)
        main_output = self.classifier(clip_features)
        return main_output
    def forward_featup_dino(self, x):
        # Normalize input for FeatUp compatibility
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)
        # Extract DINO features with FeatUp. The DINOv2 upsampler is disabled
        # in __init__, so fall back to the backbone when it is absent.
        if self.use_featup and getattr(self, 'dino_upsampler', None) is not None:
            dino_features = self.dino_upsampler(x_normalized)
            # Global average pooling to convert spatial features to a vector
            dino_features = F.adaptive_avg_pool2d(dino_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            dino_features = self.dino_model(x).squeeze()
        # Use only DINO features (no CLIP concatenation)
        main_output = self.classifier(dino_features)
        return main_output
    def forward_featup_both(self, x):
        # Normalize input
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)
        # Extract DINO features with FeatUp if available
        if self.use_featup and getattr(self, 'dino_upsampler', None) is not None:
            dino_features = self.dino_upsampler(x_normalized)
            # Global average pooling to convert the spatial map to a vector
            dino_features = F.adaptive_avg_pool2d(dino_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            # The DINO backbone already returns vector features; no pooling needed
            dino_features = self.dino_model(x).squeeze()
        # Extract CLIP features
        if self.use_featup:
            clip_features = self.clip_upsampler(x_normalized)
            # Global average pooling for the CLIP spatial features as well
            clip_features = F.adaptive_avg_pool2d(clip_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            # Standard CLIP already returns vector embeddings
            clip_features = self.clip_model.encode_image(x_normalized)
        # Both are now vectors and can be concatenated
        combined_features = torch.cat((dino_features, clip_features), dim=-1)
        main_output = self.classifier(combined_features)
        return main_output
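
# --- Reference sketch (not part of the original file) ---
# The local `mixstyle` module imported above is not shown in this file. The
# sketch below follows the published MixStyle formulation (Zhou et al.,
# "Domain Generalization with MixStyle", ICLR 2021): feature statistics are
# mixed across batch instances with probability p, with weights drawn from
# Beta(alpha, alpha). It is named MixStyleReference so it does not shadow the
# imported class; the repo's own implementation may differ, the 'crossdomain'
# mixing variant is omitted here, and the 2-D input handling is an assumption
# (the original formulation operates on 4-D feature maps).
import random

class MixStyleReference(nn.Module):
    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps

    def forward(self, x):
        # Identity outside training, and with probability 1 - p during training
        if not self.training or random.random() > self.p:
            return x
        B = x.size(0)
        # Per-channel spatial statistics for 4-D maps (original formulation);
        # per-sample feature statistics for 2-D vectors (assumption, since this
        # repo applies MixStyle to pooled backbone embeddings)
        dims = [2, 3] if x.dim() == 4 else [-1]
        mu = x.mean(dim=dims, keepdim=True)
        sig = (x.var(dim=dims, keepdim=True) + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig
        # Mixing weights from Beta(alpha, alpha), one per batch instance
        lmda = self.beta.sample((B,) + (1,) * (x.dim() - 1)).to(x.device)
        # Shuffle instances within the batch and mix their statistics
        perm = torch.randperm(B, device=x.device)
        mu_mix = mu * lmda + mu[perm] * (1 - lmda)
        sig_mix = sig * lmda + sig[perm] * (1 - lmda)
        return x_normed * sig_mix + mu_mix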
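
# --- Usage sketch (not part of the original file) ---
# A minimal example of wiring HybridModel to pretrained backbones. The hub
# entry point ("facebookresearch/dinov2" / "dinov2_vits14") and the openai
# CLIP package are assumptions about which backbones this repo targets;
# substitute the models actually used in the project.
if __name__ == "__main__":
    import clip  # assumes the openai/CLIP package is installed

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dino_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14").to(device).eval()
    clip_model, _ = clip.load("ViT-B/32", device=device)
    clip_model = clip_model.float()  # keep fp32 to match the classifier head

    model = HybridModel(dino_model, clip_model, feature_dim=512,
                        use_featup=False, use_mixstyle=False,
                        num_layers=4, output_dim=1).to(device)

    # Dummy batch of 224x224 RGB images in [0, 1]
    x = torch.rand(2, 3, 224, 224, device=device)
    with torch.no_grad():
        out = model(x)  # forward() uses only CLIP features
    print(out.shape)  # expected: torch.Size([2, 1])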