import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
import warnings
warnings.filterwarnings('ignore')

from mixstyle import MixStyle
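
# The import above assumes a local `mixstyle` module. For reference, here is
# a minimal sketch of the MixStyle layer (Zhou et al., 2021) under a private
# name so it does not shadow the import -- an illustration of the expected
# behaviour, not the project's own implementation:
from torch.distributions import Beta

class _MixStyleSketch(nn.Module):
    """Mixes per-channel mean/std statistics across a shuffled batch.

    Expects 4-D (B, C, H, W) feature maps; active only in training mode,
    firing with probability `p`; the mixing weight is drawn from
    Beta(alpha, alpha).
    """
    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        super().__init__()
        self.p = p
        self.alpha = alpha
        self.eps = eps

    def forward(self, x):
        if not self.training or torch.rand(1).item() > self.p:
            return x
        B = x.size(0)
        mu = x.mean(dim=[2, 3], keepdim=True)
        sig = (x.var(dim=[2, 3], keepdim=True) + self.eps).sqrt()
        x_norm = (x - mu) / sig
        # Mix each instance's statistics with a randomly chosen partner
        lam = Beta(self.alpha, self.alpha).sample((B, 1, 1, 1)).to(x.device)
        perm = torch.randperm(B, device=x.device)
        mu_mix = lam * mu + (1 - lam) * mu[perm]
        sig_mix = lam * sig + (1 - lam) * sig[perm]
        return x_norm * sig_mix + mu_mix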


# Feature extractor with an optional cosine-similarity branch and MixStyle
class HybridModel(nn.Module):
    def __init__(self, dino_model, clip_model, feature_dim, use_featup=False,
                 use_mixstyle=False, cosine_branch=False, num_layers=4, output_dim=1):
        super().__init__()

        self.dino_model = dino_model
        self.clip_model = clip_model
        self.use_featup = use_featup
        self.use_mixstyle = use_mixstyle
        self.cosine_branch = cosine_branch
        # NOTE: the constructor argument is overridden here; the classifier is
        # sized for CLIP ViT-B features (512-d), which is what forward() uses
        feature_dim = 512

        # Load FeatUp upsamplers. The DINOv2 upsampler is currently disabled
        # and defaults to None so the forward_featup_* paths can check for it.
        self.dino_upsampler = None
        if use_featup:
            # self.dino_upsampler = torch.hub.load("mhamilton723/FeatUp", "dinov2", use_norm=True)
            self.clip_upsampler = torch.hub.load("mhamilton723/FeatUp", "clip", use_norm=True)

            # Freeze the upsampler AND keep it in eval mode
            self.clip_upsampler.eval()
            for param in self.clip_upsampler.parameters():
                param.requires_grad = False

        # Optional MixStyle (style-statistics mixing for domain generalization).
        # Note: the reference MixStyle operates on 4-D (B, C, H, W) feature
        # maps, so applying it to pooled feature vectors requires an adapted
        # implementation.
        if use_mixstyle:
            self.mixstyle = MixStyle(p=0.5, alpha=0.1, mix='random')

        # Build the classifier MLP; hidden widths for each supported depth
        hidden_dims = {
            1: [],
            2: [256],
            3: [512, 256],
            4: [512, 256, 64],
            5: [1024, 512, 256, 64],
        }
        if num_layers not in hidden_dims:
            raise ValueError("num_layers must be between 1 and 5")

        dims = [feature_dim] + hidden_dims[num_layers] + [output_dim]
        classifier_layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            classifier_layers.append(nn.Linear(in_dim, out_dim))
            classifier_layers.append(nn.ReLU())
        classifier_layers.pop()  # no activation after the output layer

        self.classifier = nn.Sequential(*classifier_layers)
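
        # e.g. num_layers=4 gives:
        #   Linear(512, 512) -> ReLU -> Linear(512, 256) -> ReLU
        #   -> Linear(256, 64) -> ReLU -> Linear(64, output_dim)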

        # Optional branch that maps a cosine-similarity scalar to a logit and
        # fuses it with the classifier output via a learned weighted sum
        if self.cosine_branch:
            self.similarity_branch = nn.Sequential(
                nn.Linear(1, 16),
                nn.ReLU(),
                nn.Linear(16, output_dim)
            )

            # Weighted sum of similarity and classifier outputs
            self.weighted_sum_layer = nn.Linear(2, 1)

    def forward_old(self, x):
        """Legacy forward pass that concatenates DINO and CLIP features.

        Note: the classifier is sized for 512-d inputs, so this path needs
        feature_dim adjusted to the concatenated width before use.
        """
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)

        # Extract DINO features; with FeatUp, pool the upsampled spatial map
        # back to a vector so it matches the non-FeatUp path
        if self.use_featup and self.dino_upsampler is not None:
            dino_features = self.dino_upsampler(x_normalized)
            dino_features = F.adaptive_avg_pool2d(dino_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            dino_features = self.dino_model(x).squeeze()

        # Apply MixStyle to DINO features
        if self.use_mixstyle:
            dino_features = self.mixstyle(dino_features)

        # Extract CLIP features, with the same FeatUp pooling
        if self.use_featup:
            clip_features = self.clip_upsampler(x_normalized)
            clip_features = F.adaptive_avg_pool2d(clip_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            clip_features = self.clip_model.encode_image(x)

        # Apply MixStyle to CLIP features
        if self.use_mixstyle:
            clip_features = self.mixstyle(clip_features)

        # Combine features and classify
        combined_features = torch.cat((dino_features, clip_features), dim=-1)
        main_output = self.classifier(combined_features)
        return main_output

        # Cosine-similarity branch (disabled: requires noisy DINO features
        # that are no longer computed in this path):
        # cosine_sim = F.cosine_similarity(dino_features, dino_noisy_features, dim=-1).unsqueeze(-1)
        # if self.cosine_branch:
        #     similarity_output = self.similarity_branch(cosine_sim)
        #     return self.weighted_sum_layer(torch.cat((main_output, similarity_output), dim=-1))

    def forward(self, x):
        # Normalize input for FeatUp compatibility
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)

        # Extract CLIP features with FeatUp
        if self.use_featup:
            try:
                clip_features = self.clip_upsampler(x_normalized)
                # Global average pooling to convert spatial features to a vector
                clip_features = F.adaptive_avg_pool2d(clip_features, (1, 1)).squeeze(-1).squeeze(-1)
            except Exception as e:
                print(f"CLIP upsampler failed: {e}, using standard features")
                clip_features = self.clip_model.encode_image(x_normalized)
        else:
            clip_features = self.clip_model.encode_image(x_normalized)

        # Use only CLIP features (no DINO concatenation)
        main_output = self.classifier(clip_features)

        return main_output

    def forward_featup_dino(self, x):
        # Normalize input for FeatUp compatibility
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)

        # Extract DINO features with FeatUp
        if self.use_featup and self.dino_upsampler is not None:
            dino_features = self.dino_upsampler(x_normalized)
            # Global average pooling to convert spatial features to a vector
            dino_features = F.adaptive_avg_pool2d(dino_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            dino_features = self.dino_model(x).squeeze()

        # Use only DINO features (no CLIP concatenation)
        main_output = self.classifier(dino_features)

        return main_output

    def forward_featup_both(self, x):
        # Normalize input
        norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        x_normalized = norm(x)

        # Extract DINO features with FeatUp, guarding against the upsampler
        # being disabled; plain DINO already returns vector features
        if self.use_featup and self.dino_upsampler is not None:
            dino_features = self.dino_upsampler(x_normalized)
            # Global average pooling to convert to a vector
            dino_features = F.adaptive_avg_pool2d(dino_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            dino_features = self.dino_model(x).squeeze()

        # Extract CLIP features; standard CLIP already returns vectors
        if self.use_featup:
            clip_features = self.clip_upsampler(x_normalized)
            # Global average pooling for CLIP spatial features too
            clip_features = F.adaptive_avg_pool2d(clip_features, (1, 1)).squeeze(-1).squeeze(-1)
        else:
            clip_features = self.clip_model.encode_image(x_normalized)

        # Both are now vectors and can be concatenated; note that the
        # classifier must be sized for the combined width, not 512
        combined_features = torch.cat((dino_features, clip_features), dim=-1)

        main_output = self.classifier(combined_features)

        return main_output
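

# Example usage -- a minimal sketch, assuming the backbones are loaded via
# their public entry points (DINOv2 from `facebookresearch/dinov2` on
# torch.hub and OpenAI CLIP via the `clip` package); the actual training
# script may load and configure them differently.
if __name__ == "__main__":
    import clip

    device = "cpu"  # CLIP loads in fp16 on CUDA, which the fp32 classifier would reject
    dino_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14").to(device).eval()
    clip_model, _ = clip.load("ViT-B/32", device=device)

    model = HybridModel(
        dino_model=dino_model,
        clip_model=clip_model,
        feature_dim=512,   # overridden to 512 (the CLIP width) inside __init__
        use_featup=False,
        use_mixstyle=False,
        num_layers=4,
        output_dim=1,
    )

    # CLIP ViT-B/32 expects 224x224 RGB inputs
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # torch.Size([2, 1])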