import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader, Dataset, ConcatDataset, default_collate
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')  # silence torch.hub / torchvision warnings


# Load the DINOv2 backbone used as the feature extractor
dino_variant = "dinov2_vitl14"  # change this to the desired DINOv2 variant
dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)

# Inference only: switch to eval mode and move to GPU when available
dino_model.eval()
dino_model = dino_model.to("cuda" if torch.cuda.is_available() else "cpu")

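# Other DINOv2 hub entrypoints can be swapped in above: dinov2_vits14,
# dinov2_vitb14, and dinov2_vitg14 (smaller variants trade accuracy for speed).
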
# Define noise addition function
def add_noise(image, std_dev=0.1):
    """Add Gaussian noise to the input image.

    Note: the transform pipeline normalizes images first, so std_dev is in
    normalized-intensity units rather than raw pixel values.
    """
    noise = torch.randn_like(image) * std_dev
    return image + noise

# Define cosine similarity function
def cosine_similarity(embedding1, embedding2):
    """Calculate cosine similarity between two embeddings."""
    return F.cosine_similarity(embedding1, embedding2, dim=-1)

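# A single noise draw makes the similarity estimate itself somewhat noisy.
# One common stabilization (a sketch, not part of the original script) is to
# average the similarity over several independent noise draws:
def averaged_similarity(model, images, std_dev=0.1, n_draws=4):
    """Mean cosine similarity between clean and noisy embeddings over n_draws."""
    with torch.no_grad():
        base = model(images)
        sims = [F.cosine_similarity(base, model(add_noise(images, std_dev)), dim=-1)
                for _ in range(n_draws)]
    return torch.stack(sims).mean(dim=0)
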
# Define dataset class
class ImageDataset(Dataset):
    def __init__(self, folder, label, transform):
        self.folder = folder
        self.label = label
        self.transform = transform
        self.images = [
            os.path.join(folder, img)
            for img in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, img))
            and img.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        try:
            image = Image.open(image_path).convert("RGB")
            return self.transform(image), self.label
        except Exception as e:
            print(f"Skipping corrupted image {image_path}: {e}")
            # Return None; the collate_fn below drops failed items so the
            # DataLoader does not crash mid-epoch
            return None

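# Helper (added here, not in the original script) so the DataLoader tolerates
# the None items returned for unreadable files:
def collate_skip_none(batch):
    """Stack a batch with default_collate, dropping items that failed to load."""
    batch = [item for item in batch if item is not None]
    if not batch:
        return None  # every item in this batch failed
    return default_collate(batch)
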
# Define validation function with batching and GPU acceleration
def validate(fake_folder, real_folder, noise_std=0.1, threshold=0.95, batch_size=16):
    """Validate RIGID on two folders (fake and real images) and report metrics."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # resize to the model input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    fake_dataset = ImageDataset(fake_folder, label=1, transform=transform)
    real_dataset = ImageDataset(real_folder, label=0, transform=transform)

    dataset = ConcatDataset([fake_dataset, real_dataset])
    # shuffle=False keeps evaluation deterministic; collate_skip_none drops
    # unreadable images instead of crashing the loader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=collate_skip_none)

    labels = []
    predictions = []
    scores = []

    for batch in tqdm(dataloader, desc="Validating", unit="batch"):
        if batch is None:  # every image in this batch failed to load
            continue
        batch_images, batch_labels = batch
        batch_images = batch_images.to(device)

        with torch.no_grad():
            # Embeddings of the clean images
            original_embeddings = dino_model(batch_images)

            # Embeddings of the same images with Gaussian noise added
            noisy_images = add_noise(batch_images, std_dev=noise_std)
            noisy_embeddings = dino_model(noisy_images)

            # RIGID signal: generated images are less robust to perturbation,
            # so their embeddings drift further (lower cosine similarity)
            similarities = cosine_similarity(original_embeddings, noisy_embeddings)

        # Classify as fake (label 1) when similarity falls below the threshold
        batch_predictions = (similarities < threshold).long()

        labels.extend(batch_labels.numpy())
        predictions.extend(batch_predictions.cpu().numpy())
        scores.extend((1.0 - similarities).cpu().numpy())

    # Compute metrics from the labels actually collected, so any skipped
    # images cannot misalign predictions with labels
    labels = np.array(labels)
    predictions = np.array(predictions)
    scores = np.array(scores)

    fake_accuracy = float((predictions[labels == 1] == 1).mean())
    real_accuracy = float((predictions[labels == 0] == 0).mean())
    overall_accuracy = float((predictions == labels).mean())
    # Average precision is computed on the continuous score (1 - similarity),
    # with fake (label 1) as the positive class, rather than on hard predictions
    average_precision = average_precision_score(labels, scores)

    metrics = {
        "Fake Accuracy": fake_accuracy,
        "Real Accuracy": real_accuracy,
        "Overall Accuracy": overall_accuracy,
        "Average Precision": average_precision
    }

    return metrics

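# For quick spot checks, a single-image scoring helper built from the same
# pieces as validate(); a convenience sketch, not part of the original script:
def rigid_score(image_path, noise_std=0.1):
    """Return the clean-vs-noisy similarity for one image (lower = more likely fake)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        sim = cosine_similarity(dino_model(image), dino_model(add_noise(image, noise_std)))
    return sim.item()
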
# Example usage
if __name__ == "__main__":
    # Example folders (the label comes from the folder: 1_fake -> 1, 0_real -> 0)
    # fake_folder = "../../datasets/IMAGINET_unifd_test_6000/1_fake"
    # real_folder = "../../datasets/IMAGINET_unifd_test_6000/0_real"

    # fake_folder = "../../datasets/GenImage-tiny-all/1_fake"
    # real_folder = "../../datasets/GenImage-tiny-all/0_real"

    fake_folder = "../../datasets/ArtiFact_test_small/1_fake"
    real_folder = "../../datasets/ArtiFact_test_small/0_real"

    # Validate and print metrics
    metrics = validate(fake_folder, real_folder, noise_std=0.1, threshold=0.95, batch_size=256)
    print("Validation Metrics:", metrics)