
validate.py

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')
# Load the DINOv2 backbone used to embed images
dino_variant = "dinov2_vitl14"  # Change this to the desired DINOv2 variant
dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)
# Ensure the model is in evaluation mode and on the GPU if one is available
dino_model.eval()
dino_model = dino_model.to("cuda" if torch.cuda.is_available() else "cpu")
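# torch.hub.load downloads and caches the weights on first use (under ~/.cache/torch/hub).
# The DINOv2 hub also exposes "dinov2_vits14", "dinov2_vitb14" and "dinov2_vitg14"
# if a smaller or larger backbone is preferred.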
# Define noise addition function
def add_noise(image, std_dev=0.1):
    """Add Gaussian noise to the input image."""
    noise = torch.randn_like(image) * std_dev
    return image + noise
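# Note: in validate() below the noise is added after ToTensor/Normalize, so std_dev
# is expressed in normalized-tensor units rather than raw [0, 1] pixel units.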
# Define cosine similarity function
def cosine_similarity(embedding1, embedding2):
    """Calculate cosine similarity between two embeddings."""
    return F.cosine_similarity(embedding1, embedding2, dim=-1)
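# Detection idea (RIGID): DINOv2 embeddings of real images tend to stay close to
# themselves under small input perturbations, while embeddings of generated images
# shift more. A clean-vs-noisy cosine similarity below `threshold` is therefore
# treated as evidence that the image is fake (label 1).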
# Define dataset class
class ImageDataset(Dataset):
    def __init__(self, folder, label, transform):
        self.folder = folder
        self.label = label
        self.transform = transform
        self.images = [
            os.path.join(folder, img)
            for img in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, img))
            and img.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        try:
            image = Image.open(image_path).convert("RGB")
            return self.transform(image), self.label
        except Exception as e:
            print(f"Skipping corrupted image {image_path}: {e}")
            # Returning None lets the collate function below drop this sample;
            # raising here would abort the whole DataLoader iteration instead
            return None
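# The default collate function cannot batch the None returned for corrupted images,
# so this small helper (added here, not part of the original skip-by-IndexError logic)
# filters them out before batching. If every image in a batch is unreadable it returns
# None, which the validation loop skips.
def skip_none_collate(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return torch.utils.data.default_collate(batch)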
# Define validation function with batching and GPU acceleration
def validate(fake_folder, real_folder, noise_std=0.1, threshold=0.95, batch_size=16):
    """Validate RIGID on two folders (fake and real images) and report metrics."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to the model input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    fake_dataset = ImageDataset(fake_folder, label=1, transform=transform)
    real_dataset = ImageDataset(real_folder, label=0, transform=transform)
    dataset = torch.utils.data.ConcatDataset([fake_dataset, real_dataset])
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=skip_none_collate)
    labels = []
    predictions = []
    scores = []  # 1 - cosine similarity; higher means "more likely fake"
    for batch in tqdm(dataloader, desc="Validating", unit="batch"):
        if batch is None:
            # Every image in this batch was unreadable and was filtered out
            continue
        batch_images, batch_labels = batch
        batch_images = batch_images.to(device)
        with torch.no_grad():
            # Embeddings of the original images
            original_embeddings = dino_model(batch_images)
            # Embeddings of the same images after adding Gaussian noise
            noisy_images = add_noise(batch_images, std_dev=noise_std)
            noisy_embeddings = dino_model(noisy_images)
            # Cosine similarity between clean and noisy embeddings
            similarities = cosine_similarity(original_embeddings, noisy_embeddings)
            # A similarity below the threshold is predicted as fake (label 1)
            batch_predictions = (similarities < threshold).long()
        labels.extend(batch_labels.numpy())
        predictions.extend(batch_predictions.cpu().numpy())
        scores.extend((1.0 - similarities).cpu().numpy())
    # Compute metrics on the samples that were actually processed, so skipped
    # images cannot shift the alignment between labels and predictions
    labels = np.array(labels)
    predictions = np.array(predictions)
    scores = np.array(scores)
    fake_accuracy = float((predictions[labels == 1] == 1).mean())
    real_accuracy = float((predictions[labels == 0] == 0).mean())
    overall_accuracy = float((predictions == labels).mean())
    # Average precision from the continuous score, so it does not depend on the threshold
    average_precision = average_precision_score(labels, scores)
    metrics = {
        "Fake Accuracy": fake_accuracy,
        "Real Accuracy": real_accuracy,
        "Overall Accuracy": overall_accuracy,
        "Average Precision": average_precision
    }
    return metrics
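# Optional: a minimal sketch (not part of the original script) for choosing `threshold`
# from a folder of real images alone. It reuses ImageDataset, add_noise and the collate
# helper above, and takes a low percentile of the real-image similarities so that most
# real images land above the returned threshold. The function name, `percentile`
# parameter and default values are illustrative assumptions.
def estimate_threshold(real_folder, noise_std=0.1, percentile=5, batch_size=64):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    dataset = ImageDataset(real_folder, label=0, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        collate_fn=skip_none_collate)
    sims = []
    for batch in loader:
        if batch is None:
            continue
        images, _ = batch
        images = images.to(device)
        with torch.no_grad():
            emb = dino_model(images)
            noisy_emb = dino_model(add_noise(images, std_dev=noise_std))
        sims.extend(cosine_similarity(emb, noisy_emb).cpu().numpy())
    return float(np.percentile(sims, percentile))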
# Example usage
if __name__ == "__main__":
    # Example folders
    # fake_folder = "../../datasets/IMAGINET_unifd_test_6000/1_fake"
    # real_folder = "../../datasets/IMAGINET_unifd_test_6000/0_real"
    # fake_folder = "../../datasets/GenImage-tiny-all/1_fake"
    # real_folder = "../../datasets/GenImage-tiny-all/0_real"
    fake_folder = "../../datasets/ArtiFact_test_small/1_fake"
    real_folder = "../../datasets/ArtiFact_test_small/0_real"
    # Validate and print metrics
    metrics = validate(fake_folder, real_folder, noise_std=0.1, threshold=0.95, batch_size=256)
    print("Validation Metrics:", metrics)