import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T

from mixstyle import MixStyle

warnings.filterwarnings('ignore')


# Feature extractor combining DINOv2 and CLIP backbones, with an optional
# cosine-similarity branch and optional MixStyle feature augmentation
class HybridModel(nn.Module):
    def __init__(self, dino_model, clip_model, feature_dim, use_featup=False,
                 use_mixstyle=False, cosine_branch=False, num_layers=4, output_dim=1):
        super().__init__()
        self.dino_model = dino_model
        self.clip_model = clip_model
        self.use_featup = use_featup
        self.use_mixstyle = use_mixstyle
        self.cosine_branch = cosine_branch
        # NOTE: overrides the constructor argument; the classifier below always
        # expects 512-d input (e.g. CLIP ViT-B/32 image embeddings)
        feature_dim = 512

        # Load FeatUp upsamplers (the DINOv2 upsampler is currently disabled;
        # it is kept as None so the attribute checks below do not raise)
        self.dino_upsampler = None
        if use_featup:
            # self.dino_upsampler = torch.hub.load("mhamilton723/FeatUp", "dinov2", use_norm=True)
            self.clip_upsampler = torch.hub.load("mhamilton723/FeatUp", "clip", use_norm=True)

            # Freeze AND set to eval mode
            # self.dino_upsampler.eval()
            self.clip_upsampler.eval()
            # for param in self.dino_upsampler.parameters():
            #     param.requires_grad = False
            for param in self.clip_upsampler.parameters():
                param.requires_grad = False

        # Add MixStyle
        if use_mixstyle:
            self.mixstyle = MixStyle(p=0.5, alpha=0.1, mix='random')

        # Build classifier based on number of layers
        if num_layers == 1:
            classifier_layers = [
                nn.Linear(feature_dim, output_dim)
            ]
        elif num_layers == 2:
            classifier_layers = [
                nn.Linear(feature_dim, 256), nn.ReLU(),
                nn.Linear(256, output_dim)
            ]
        elif num_layers == 3:
            classifier_layers = [
                nn.Linear(feature_dim, 512), nn.ReLU(),
                nn.Linear(512, 256), nn.ReLU(),
                nn.Linear(256, output_dim)
            ]
        elif num_layers == 4:
            classifier_layers = [
                nn.Linear(feature_dim, 512), nn.ReLU(),
                nn.Linear(512, 256), nn.ReLU(),
                nn.Linear(256, 64), nn.ReLU(),
                nn.Linear(64, output_dim)
            ]
        elif num_layers == 5:
            classifier_layers = [
                nn.Linear(feature_dim, 1024), nn.ReLU(),
                nn.Linear(1024, 512), nn.ReLU(),
                nn.Linear(512, 256), nn.ReLU(),
                nn.Linear(256, 64), nn.ReLU(),
                nn.Linear(64, output_dim)
            ]
        else:
            raise ValueError("num_layers must be between 1 and 5")

        self.classifier = nn.Sequential(*classifier_layers)

        if self.cosine_branch:
            self.similarity_branch = nn.Sequential(
                nn.Linear(1, 16),
                nn.ReLU(),
                nn.Linear(16, output_dim)
            )
            # Weighted sum of similarity and classifier outputs
            self.weighted_sum_layer = nn.Linear(2, 1)

    def forward_old(self, x):
        # Legacy forward pass; depends on the (currently disabled) DINO
        # upsampler when use_featup is True
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized_featup = norm(x)

        # Apply FeatUp if enabled (the result is immediately overwritten by the
        # plain DINO features below; kept for reference)
        if self.use_featup:
            print("======= Using Featup for DINOv2 =======")
            dino_original_features = self.dino_upsampler(x_normalized_featup)
            # dino_noisy_features_upsampled = self.dino_upsampler(x)

        # Extract original and noisy features from DINO
        dino_original_features = self.dino_model(x).squeeze()
        # dino_noisy_features = self.dino_model(noisy_x).squeeze()

        # Apply MixStyle to DINO features (the original branched on use_featup,
        # but both branches were identical)
        if self.use_mixstyle:
            dino_original_features = self.mixstyle(dino_original_features)

        # Apply FeatUp on CLIP features if enabled
        if self.use_featup:
            print("======= Using Featup for CLIP =======")
            clip_features = self.clip_upsampler(x_normalized_featup)
            # clip_features = self.clip_model.encode_image(x)
        else:
            # Extract CLIP features
            clip_features = self.clip_model.encode_image(x)

        # Apply MixStyle to CLIP features
        if self.use_mixstyle:
            clip_features = self.mixstyle(clip_features)

        # # Compute cosine similarity
        # if self.use_featup:
        #     cosine_sim = F.cosine_similarity(dino_original_features, dino_noisy_features_upsampled, dim=-1).unsqueeze(-1)
        # else:
        #     cosine_sim = F.cosine_similarity(dino_original_features, dino_noisy_features, dim=-1).unsqueeze(-1)

        # Combine features
        combined_features = torch.cat((dino_original_features, clip_features), dim=-1)
        main_output = self.classifier(combined_features)
        # main_output = self.classifier(clip_features.unsqueeze(0))
        return main_output

        # # Similarity output
        # if self.cosine_branch:
        #     similarity_output = self.similarity_branch(cosine_sim)
        #     # Weighted sum
        #     final_output = self.weighted_sum_layer(torch.cat((main_output, similarity_output), dim=-1))
        #     return final_output
        # else:
        #     return main_output

    def forward(self, x):
        # Normalize input for FeatUp compatibility
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)

        # Extract CLIP features with FeatUp
        if self.use_featup:
            try:
                clip_features = self.clip_upsampler(x_normalized)
                # Global average pooling to convert spatial features to vector
                clip_features = F.adaptive_avg_pool2d(clip_features, (1, 1)).squeeze(-1).squeeze(-1)
            except Exception as e:
                print(f"CLIP upsampler failed: {e}, using standard features")
                clip_features = self.clip_model.encode_image(x_normalized)
        else:
            clip_features = self.clip_model.encode_image(x_normalized)

        # Use only CLIP features (no DINO concatenation)
        main_output = self.classifier(clip_features)
        return main_output

    def forward_featup_dino(self, x):
        # Normalize input for FeatUp compatibility
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)

        # Extract DINO features with FeatUp
        if self.use_featup and self.dino_upsampler is not None:
            dino_features = self.dino_upsampler(x_normalized)
            # Global average pooling to convert spatial features to vector
            dino_features = F.adaptive_avg_pool2d(dino_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            dino_features = self.dino_model(x).squeeze()

        # Use only DINO features (no CLIP concatenation)
        main_output = self.classifier(dino_features)
        return main_output

    def forward_featup_both(self, x):
        # Requires the DINOv2 upsampler, which is currently disabled in
        # __init__; re-enable it before calling this with use_featup=True
        # Normalize input
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)

        # Extract DINO features with FeatUp
        if self.use_featup:
            dino_features = self.dino_upsampler(x_normalized)
            # Global average pooling to convert to vector
            dino_features = F.adaptive_avg_pool2d(dino_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            # DINO already returns vector features, no pooling needed
            dino_features = self.dino_model(x).squeeze()

        # Extract CLIP features
        if self.use_featup:
            clip_features = self.clip_upsampler(x_normalized)
            # Global average pooling for CLIP spatial features too
            clip_features = F.adaptive_avg_pool2d(clip_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            # Standard CLIP already returns vectors
            clip_features = self.clip_model.encode_image(x_normalized)

        # Now both are vectors and can be concatenated
        combined_features = torch.cat((dino_features, clip_features), dim=-1)
        # print(f"DINO features shape: {dino_features.shape}")
        # print(f"CLIP features shape: {clip_features.shape}")
        # print(f"Combined features shape: {combined_features.shape}")

        main_output = self.classifier(combined_features)
        return main_output
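

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the backbone choices
# below are assumptions for illustration only. It loads a DINOv2 ViT-S/14 via
# torch.hub and a CLIP ViT-B/32 via OpenAI's `clip` package, then runs the
# default (CLIP-only) forward pass on a dummy batch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import clip  # assumed dependency: https://github.com/openai/CLIP

    device = "cpu"  # keep everything fp32; clip.load() returns fp16 weights on CUDA

    # Hypothetical backbones; swap in whichever models the project actually uses
    dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14").to(device).eval()
    clip_model, _ = clip.load("ViT-B/32", device=device)

    # feature_dim is overridden to 512 inside __init__, so forward() expects
    # 512-d classifier inputs -- which CLIP ViT-B/32 embeddings provide
    model = HybridModel(dino, clip_model, feature_dim=512,
                        num_layers=4, output_dim=1).to(device)

    x = torch.rand(2, 3, 224, 224, device=device)  # dummy batch in [0, 1]
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([2, 1])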