"""RIGID-style AI-generated image detection.

Embeds each image with a frozen DINOv2 backbone, perturbs the image with
Gaussian noise, re-embeds it, and predicts "fake" when the cosine similarity
between the two embeddings falls below a threshold (generated images tend to
be less robust to small perturbations than real photographs).
"""

import os
import warnings

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import requests
import clip
from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

warnings.filterwarnings('ignore')

# Load the (frozen) DINOv2 backbone once at module import.
dino_variant = "dinov2_vitl14"  # change this to the desired DINOv2 variant
dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)
dino_model.eval()  # inference only — no dropout / batch-norm updates
dino_model = dino_model.to("cuda" if torch.cuda.is_available() else "cpu")

# Lowercase suffixes only: the check below applies name.lower() first, so the
# original duplicate upper-case variants were redundant.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')


def add_noise(image, std_dev=0.1):
    """Return *image* plus zero-mean Gaussian noise with std-dev *std_dev*."""
    noise = torch.randn_like(image) * std_dev
    return image + noise


def cosine_similarity(embedding1, embedding2):
    """Cosine similarity between two embedding tensors along the feature dim."""
    return F.cosine_similarity(embedding1, embedding2, dim=-1)


class ImageDataset(Dataset):
    """A flat folder of images that all share one label (1 = fake, 0 = real)."""

    def __init__(self, folder, label, transform):
        self.folder = folder
        self.label = label
        self.transform = transform
        self.images = [
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, name))
            and name.lower().endswith(_IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(tensor, label)``, or ``None`` if the image is unreadable.

        BUG FIX: the original raised ``IndexError`` for corrupted files. That
        exception escapes the DataLoader's ``__next__`` — which runs in the
        ``for`` statement itself, outside any per-batch ``try`` — crashing
        validation. Returning ``None`` and filtering in the collate function
        also keeps each prediction paired with its true label.
        """
        image_path = self.images[idx]
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            print(f"Skipping corrupted image {image_path}: {e}")
            return None
        return self.transform(image), self.label


def _skip_none_collate(batch):
    """Collate fn that drops unreadable (None) samples; None if all dropped."""
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return torch.utils.data.default_collate(batch)


def validate(fake_folder, real_folder, noise_std=0.1, threshold=0.95, batch_size=16):
    """Validate RIGID on a fake and a real image folder and report metrics.

    Args:
        fake_folder: directory of generated images (label 1).
        real_folder: directory of real images (label 0).
        noise_std: std-dev of the Gaussian perturbation.
        threshold: similarity below this value => predicted fake.
        batch_size: DataLoader batch size.

    Returns:
        dict with "Fake Accuracy", "Real Accuracy", "Overall Accuracy",
        and "Average Precision" (NaN for a class with no valid samples).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # resize to model input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    fake_dataset = ImageDataset(fake_folder, label=1, transform=transform)
    real_dataset = ImageDataset(real_folder, label=0, transform=transform)
    dataset = torch.utils.data.ConcatDataset([fake_dataset, real_dataset])
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=_skip_none_collate)

    labels, predictions, scores = [], [], []
    for batch in tqdm(dataloader, desc="Validating", unit="batch"):
        if batch is None:  # every image in this batch was unreadable
            continue
        batch_images, batch_labels = batch
        batch_images = batch_images.to(device)

        with torch.no_grad():
            # BUG FIX: no .squeeze() here — with batch_size 1 it dropped the
            # batch dimension and made the cosine similarity reduce over the
            # wrong axis. DINOv2 already returns (batch, feature_dim).
            original_embeddings = dino_model(batch_images)
            noisy_images = add_noise(batch_images, std_dev=noise_std)
            noisy_embeddings = dino_model(noisy_images)

        similarities = cosine_similarity(original_embeddings, noisy_embeddings)
        # Lower similarity under perturbation => predicted fake.
        batch_predictions = (similarities < threshold).long()

        labels.extend(batch_labels.numpy())
        predictions.extend(batch_predictions.cpu().numpy())
        # Continuous "fakeness" score, kept for average precision below.
        scores.extend((1.0 - similarities).cpu().numpy())

    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    scores = np.asarray(scores)

    # BUG FIX: the original rebuilt the label list from the dataset sizes and
    # sliced `predictions` positionally; any skipped image desynchronised the
    # two. Compute per-class accuracy from the labels actually collected.
    correct = predictions == labels
    fake_mask = labels == 1
    real_mask = labels == 0
    fake_accuracy = float(correct[fake_mask].mean()) if fake_mask.any() else float("nan")
    real_accuracy = float(correct[real_mask].mean()) if real_mask.any() else float("nan")
    overall_accuracy = float(correct.mean()) if len(labels) else float("nan")

    # BUG FIX: average precision is a ranking metric — feed it the continuous
    # similarity-based score, not the already-thresholded 0/1 predictions.
    if len(labels) and fake_mask.any() and real_mask.any():
        average_precision = float(average_precision_score(labels, scores))
    else:
        average_precision = float("nan")

    return {
        "Fake Accuracy": fake_accuracy,
        "Real Accuracy": real_accuracy,
        "Overall Accuracy": overall_accuracy,
        "Average Precision": average_precision,
    }


if __name__ == "__main__":
    # Example folders (alternative datasets left commented for convenience):
    # fake_folder = "../../datasets/IMAGINET_unifd_test_6000/1_fake"
    # real_folder = "../../datasets/IMAGINET_unifd_test_6000/0_real"
    # fake_folder = "../../datasets/GenImage-tiny-all/1_fake"
    # real_folder = "../../datasets/GenImage-tiny-all/0_real"
    fake_folder = "../../datasets/ArtiFact_test_small/1_fake"
    real_folder = "../../datasets/ArtiFact_test_small/0_real"

    # Validate and print metrics
    metrics = validate(fake_folder, real_folder, noise_std=0.1,
                       threshold=0.95, batch_size=256)
    print("Validation Metrics:", metrics)