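"""Train or evaluate a hybrid DINOv2 + CLIP classifier for fake-image detection.

Illustrative invocations (the script name and all paths are placeholders):

    # Training
    python train_hybrid.py --train_fake_dir FAKE_TRAIN --train_real_dir REAL_TRAIN \
        --test_fake_dir FAKE_TEST --test_real_dir REAL_TEST --gpu 0,1

    # Evaluation only (the four directory arguments are still required)
    python train_hybrid.py --eval --model_path CHECKPOINT.pth \
        --train_fake_dir FAKE_TRAIN --train_real_dir REAL_TRAIN \
        --test_fake_dir FAKE_TEST --test_real_dir REAL_TEST
"""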
import argparse
import datetime
import os
import warnings

import numpy as np
import open_clip
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, average_precision_score
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

warnings.filterwarnings('ignore')

from dataset import ImageDataset
from model import HybridModel


# Utility function for computing accuracy and average precision
def calculate_metrics(predictions, labels):
    # Average precision must be computed on the raw scores; thresholding first
    # would discard the ranking information it depends on.
    avg_precision = average_precision_score(labels, predictions)
    predictions = (predictions > 0.5).astype(int)
    accuracy = accuracy_score(labels, predictions)
    return accuracy, avg_precision
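# For example, calculate_metrics(np.array([0.9, 0.2, 0.6]), np.array([1, 0, 1]))
# scores the raw predictions, thresholds them at 0.5, and returns
# (accuracy=1.0, avg_precision=1.0).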


# Training function
def train_model(args):
    # Example (clip_variant, clip_pretrained_dataset) pairs:
    # (ViT-L-14-quickgelu, dfn2b), (ViT-H-14-quickgelu, dfn5b)

    # Set device
    device_ids = [int(gpu) for gpu in args.gpu.split(",")]
    device = torch.device(f"cuda:{device_ids[0]}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the DINOv2 and CLIP backbones
    dino_model = torch.hub.load('facebookresearch/dinov2', args.dino_variant)
    clip_model, _, preprocess = open_clip.create_model_and_transforms(args.clip_variant, pretrained=args.clip_pretrained_dataset)
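    # `preprocess` is open_clip's evaluation transform; it is unused here
    # because the script builds its own transform pipeline below.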

    # Freeze CLIP, optionally unfreezing a late transformer block for finetuning
    for param in clip_model.parameters():
        param.requires_grad = False
    if args.finetune_clip:
        for name, param in clip_model.named_parameters():
            if 'transformer.resblocks.21' in name:  # example: unfreeze a late block
                param.requires_grad = True

    print("Unfrozen layers in CLIP model:")
    for name, param in clip_model.named_parameters():
        if param.requires_grad:
            print(name)

    # Freeze DINOv2, optionally unfreezing parts of its last block for finetuning
    for param in dino_model.parameters():
        param.requires_grad = False
    if args.finetune_dino:
        for name, param in dino_model.named_parameters():
            if 'blocks.23.attn.proj' in name or 'blocks.23.mlp.fc2' in name:
                param.requires_grad = True

    print("Unfrozen layers in DINOv2 model:")
    for name, param in dino_model.named_parameters():
        if param.requires_grad:
            print(name)

    # Model initialization: the head consumes concatenated DINOv2 + CLIP features
    dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[args.dino_variant]
    clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1024}[args.clip_variant]
    print("clip_dim:", clip_dim, "dino_dim:", dino_dim)
    feature_dim = dino_dim + clip_dim
    hybrid_model = HybridModel(dino_model, clip_model, feature_dim, args.featup, args.mixstyle, args.cosine_branch, args.num_layers).to(device)

    # Wrap model for multi-GPU training if more than one device was given
    if len(device_ids) > 1:
        print("Using nn.DataParallel (multiple GPUs)")
        hybrid_model = nn.DataParallel(hybrid_model, device_ids=device_ids)

    # Print trainable parameters
    print("Trainable parameters:", sum(p.numel() for p in hybrid_model.parameters() if p.requires_grad))

    # Data preparation: resizing, light augmentation, and normalization
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
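    # These are the standard ImageNet normalization statistics rather than the
    # CLIP-specific ones carried by `preprocess`; both backbones receive the
    # same normalized tensor.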

    train_fake_dataset = ImageDataset(args.train_fake_dir, label=1, transform=transform, is_train=True, aug_prob=args.aug_prob)
    train_real_dataset = ImageDataset(args.train_real_dir, label=0, transform=transform, is_train=True, aug_prob=args.aug_prob)
    # NOTE: the in-training test split reuses the training transform and
    # augmentations; validate_model() below evaluates on clean images instead.
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform, is_train=True, aug_prob=args.aug_prob)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform, is_train=True, aug_prob=args.aug_prob)

    train_dataset = torch.utils.data.ConcatDataset([train_fake_dataset, train_real_dataset])
    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

    # Optimizer and loss function
    optimizer = optim.Adam([p for p in hybrid_model.parameters() if p.requires_grad], lr=args.lr)
    criterion = nn.BCEWithLogitsLoss()
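    # BCEWithLogitsLoss applies the sigmoid internally, so HybridModel is
    # expected to output raw logits; sigmoid is applied explicitly only where
    # probabilities are needed for evaluation.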

    # Training loop
    for epoch in range(args.epochs):
        hybrid_model.train()
        running_loss = 0.0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}", unit="batch"):
            images, labels = images.to(device), labels.float().to(device)

            # Optional input perturbation, controlled by --noise_std:
            # noisy_images = images + torch.randn_like(images) * args.noise_std

            optimizer.zero_grad()
            # reshape(-1) flattens (B, 1) logits to (B,) and, unlike squeeze(),
            # does not collapse a batch of size 1 to a 0-d tensor
            outputs = hybrid_model(images).reshape(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch + 1}: Loss = {running_loss / len(train_loader):.4f}")

        # Evaluation after each epoch
        hybrid_model.eval()
        all_preds, all_labels = [], []

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.float().to(device)
                outputs = torch.sigmoid(hybrid_model(images).reshape(-1))
                all_preds.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        accuracy, avg_precision = calculate_metrics(np.array(all_preds), np.array(all_labels))
        print(f"Epoch {epoch + 1}: Test Accuracy = {accuracy:.4f}, Average Precision = {avg_precision:.4f}")

        # Save the model, unwrapping DataParallel so the checkpoint can be
        # loaded into a bare HybridModel by validate_model()
        if (epoch + 1) % args.save_every == 0:
            model_to_save = hybrid_model.module if isinstance(hybrid_model, nn.DataParallel) else hybrid_model
            save_path = os.path.join(args.save_model_path, f"ep{epoch + 1}_acc_{accuracy:.4f}_ap_{avg_precision:.4f}.pth")
            torch.save(model_to_save.state_dict(), save_path)
            print(f"Model saved at {save_path}")


def save_args_to_file(args, save_path):
    """Save the parser arguments to a text file."""
    with open(save_path, "w") as f:
        for arg, value in vars(args).items():
            f.write(f"{arg}: {value}\n")


def validate_model(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the backbones and rebuild the wrapper exactly as in training
    dino_model = torch.hub.load('facebookresearch/dinov2', args.dino_variant)
    clip_model, _, preprocess = open_clip.create_model_and_transforms(args.clip_variant, pretrained=args.clip_pretrained_dataset)

    dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[args.dino_variant]
    clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1024}[args.clip_variant]
    feature_dim = dino_dim + clip_dim
    model = HybridModel(dino_model, clip_model, feature_dim, args.featup, args.mixstyle, args.cosine_branch, args.num_layers).to(device)

    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    # Clean evaluation transform: no augmentation, same normalization as training
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Optional JPEG/blur corruptions for robustness evaluation
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform, is_train=False, jpeg_qf=args.jpeg, gaussian_blur=args.blur)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform, is_train=False, jpeg_qf=args.jpeg, gaussian_blur=args.blur)

    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating", unit="batch"):
            images, labels = images.to(device), labels.float().to(device)
            outputs = torch.sigmoid(model(images).reshape(-1))
            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    preds_bin = (np.array(all_preds) > 0.5).astype(int)
    labels_np = np.array(all_labels)

    # Overall accuracy and average precision, plus per-class accuracy
    overall_acc = accuracy_score(labels_np, preds_bin)
    avg_precision = average_precision_score(labels_np, all_preds)
    real_acc = accuracy_score(labels_np[labels_np == 0], preds_bin[labels_np == 0])
    fake_acc = accuracy_score(labels_np[labels_np == 1], preds_bin[labels_np == 1])

    print(f"[Validation] Accuracy: {overall_acc:.4f}, Avg Precision: {avg_precision:.4f}, Real Acc: {real_acc:.4f}, Fake Acc: {fake_acc:.4f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Hybrid Model for Fake Image Detection")
    parser.add_argument("--train_fake_dir", type=str, required=True, help="Path to training fake images directory")
    parser.add_argument("--train_real_dir", type=str, required=True, help="Path to training real images directory")
    parser.add_argument("--test_fake_dir", type=str, required=True, help="Path to testing fake images directory")
    parser.add_argument("--test_real_dir", type=str, required=True, help="Path to testing real images directory")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--noise_std", type=float, default=0.1, help="Std of Gaussian noise for input perturbation (used only if the commented noise lines are enabled)")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="GPU device(s) to use (e.g., '0' or '0,1')")
    parser.add_argument("--save_every", type=int, default=1, help="Save the model every n epochs")
    parser.add_argument("--num_layers", type=int, default=4, help="Number of layers in the classifier head (1 to 5)")
    parser.add_argument("--save_model_path", type=str, default='/media/external_16TB_1/amirtaha_amanzadi/DINO', help="Path to save the model")
    parser.add_argument("--dino_variant", type=str, default="dinov2_vitl14", help="DINOv2 variant")
    parser.add_argument("--clip_variant", type=str, default="ViT-L-14", help="CLIP variant")
    parser.add_argument("--clip_pretrained_dataset", type=str, default="datacomp_xl_s13b_b90k", help="CLIP pretrained dataset")
    parser.add_argument("--finetune_dino", action="store_true", help="Finetune part of the DINOv2 model during training")
    parser.add_argument("--finetune_clip", action="store_true", help="Finetune part of the CLIP model during training")
    parser.add_argument("--featup", action="store_true", help="Use FeatUp-style feature upsampling during training")
    parser.add_argument("--mixstyle", action="store_true", help="Apply MixStyle for domain generalization during training")
    parser.add_argument("--cosine_branch", action="store_true", help="Use the cosine branch during training")
    parser.add_argument("--jpeg", type=int, nargs='*', default=None, help="Apply JPEG compression with the given quality factors (e.g., 95 90 75 50)")
    parser.add_argument("--blur", type=float, nargs='*', default=None, help="Apply Gaussian blur with the given sigma values (e.g., 1 2 3)")
    parser.add_argument("--aug_prob", type=float, default=0, help="Probability of applying JPEG/blur augmentations during training")
    parser.add_argument("--eval", action="store_true", help="Run validation only")
    parser.add_argument("--model_path", type=str, help="Path to the saved model for evaluation (required with --eval)")

    args = parser.parse_args()
    print(args)

    # Set GPUs
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.eval:
        assert args.model_path, "--model_path is required when --eval is set"
        validate_model(args)
    else:
        # Create a timestamped log directory for this run
        timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        args.save_model_path = os.path.join(args.save_model_path, f"Log_{timestamp}")
        os.makedirs(args.save_model_path, exist_ok=True)

        # Save args to a file in the log directory
        args_file_path = os.path.join(args.save_model_path, "args.txt")
        save_args_to_file(args, args_file_path)

        train_model(args)