import argparse
import clip
import open_clip
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
from sklearn.metrics import accuracy_score, average_precision_score
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from dataset import ImageDataset
from model import HybridModel


# Utility function for calculating metrics
def calculate_metrics(predictions, labels):
    # Accuracy on thresholded predictions; average precision on the raw scores
    pred_labels = (predictions > 0.5).astype(int)
    accuracy = accuracy_score(labels, pred_labels)
    avg_precision = average_precision_score(labels, predictions)
    return accuracy, avg_precision


# Training function
def train_model(args):
    # Alternative CLIP variants: (ViT-L-14-quickgelu, dfn2b), (ViT-H-14-quickgelu, dfn5b)

    # Set device
    device_ids = [int(gpu) for gpu in args.gpu.split(",")]
    device = torch.device(f"cuda:{device_ids[0]}" if torch.cuda.is_available() else "cpu")
    print(device)

    # Load DINOv2 and CLIP models
    dino_model = torch.hub.load('facebookresearch/dinov2', args.dino_variant)
    # clip_model, _ = clip.load(clip_variant, device)
    clip_model, _, preprocess = open_clip.create_model_and_transforms(
        args.clip_variant, pretrained=args.clip_pretrained_dataset)

    # Freeze CLIP; optionally unfreeze the last blocks for finetuning
    for name, param in clip_model.named_parameters():
        param.requires_grad = False
    if args.finetune_clip:
        for name, param in clip_model.named_parameters():
            if 'transformer.resblocks.21' in name:  # Example: unfreeze last few blocks
                param.requires_grad = True
        print("Unfrozen layers in CLIP model:")
        for name, param in clip_model.named_parameters():
            if param.requires_grad:
                print(name)

    # Freeze DINO; optionally unfreeze the last block's projections for finetuning
    for name, param in dino_model.named_parameters():
        param.requires_grad = False
    if args.finetune_dino:
        for name, param in dino_model.named_parameters():
            if 'blocks.23.attn.proj' in name or 'blocks.23.mlp.fc2' in name:
                param.requires_grad = True
        print("Unfrozen layers in DINOv2 model:")
        for name, param in dino_model.named_parameters():
            if param.requires_grad:
                print(name)

    # Model initialization
    dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[args.dino_variant]
    clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1024}[args.clip_variant]
    print("clip_dim:", clip_dim, "dino_dim:", dino_dim)
    feature_dim = dino_dim + clip_dim
    # feature_dim = clip_dim
    hybrid_model = HybridModel(dino_model, clip_model, feature_dim, args.featup,
                               args.mixstyle, args.cosine_branch, args.num_layers).to(device)

    # Wrap model for multi-GPU if applicable
    if len(device_ids) > 1:
        print("Using Parallel NN (multiple GPUs)")
        hybrid_model = nn.DataParallel(hybrid_model, device_ids=device_ids)

    # Print trainable parameters
    print("Trainable parameters:", sum(p.numel() for p in hybrid_model.parameters() if p.requires_grad))

    # Data preparation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),                                         # Random Flip
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color Jitter
        transforms.RandomRotation(degrees=10),                                          # Random Rotation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # train_fake_dataset = ImageDataset(args.train_fake_dir, label=1, transform=transform)
    # train_real_dataset = ImageDataset(args.train_real_dir, label=0, transform=transform)
    # test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform)
    # test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform)
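    # ImageDataset (defined in dataset.py) is assumed here to accept, beyond the directory,
    # label, and transform: is_train / aug_prob for training-time JPEG/blur augmentation, and
    # jpeg_qf / gaussian_blur for fixed test-time corruption (see the --jpeg/--blur/--aug_prob
    # flags below). These names simply mirror how the class is called in this script.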
    train_fake_dataset = ImageDataset(args.train_fake_dir, label=1, transform=transform, is_train=True, aug_prob=args.aug_prob)
    train_real_dataset = ImageDataset(args.train_real_dir, label=0, transform=transform, is_train=True, aug_prob=args.aug_prob)
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform, is_train=True, aug_prob=args.aug_prob)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform, is_train=True, aug_prob=args.aug_prob)

    train_dataset = torch.utils.data.ConcatDataset([train_fake_dataset, train_real_dataset])
    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

    # Optimizer and loss function
    optimizer = optim.Adam([p for p in hybrid_model.parameters() if p.requires_grad], lr=args.lr)
    criterion = nn.BCEWithLogitsLoss()

    # Training loop
    for epoch in range(args.epochs):
        hybrid_model.train()
        running_loss = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}", unit="batch"):
            images, labels = images.to(device), labels.float().to(device)
            # Create noise-perturbed images
            # noisy_images = images + torch.randn_like(images) * args.noise_std
            optimizer.zero_grad()
            outputs = hybrid_model(images).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}: Loss = {running_loss / len(train_loader):.4f}")

        # Evaluation
        hybrid_model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.float().to(device)
                # Create noise-perturbed images
                # noisy_images = images + torch.randn_like(images) * args.noise_std
                outputs = torch.sigmoid(hybrid_model(images).squeeze())
                all_preds.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        accuracy, avg_precision = calculate_metrics(np.array(all_preds), np.array(all_labels))
        print(f"Epoch {epoch + 1}: Test Accuracy = {accuracy:.4f}, Average Precision = {avg_precision:.4f}")

        # Save model
        if (epoch + 1) % args.save_every == 0:
            save_path = os.path.join(args.save_model_path, f"ep{epoch + 1}_acc_{accuracy:.4f}_ap_{avg_precision:.4f}.pth")
            # Unwrap DataParallel before saving so the checkpoint loads into a plain HybridModel
            state_dict = hybrid_model.module.state_dict() if isinstance(hybrid_model, nn.DataParallel) else hybrid_model.state_dict()
            torch.save(state_dict, save_path)
            print(f"Model saved at {save_path}")


def save_args_to_file(args, save_path):
    """Save the parser arguments to a text file."""
    with open(save_path, "w") as f:
        for arg, value in vars(args).items():
            f.write(f"{arg}: {value}\n")


def validate_model(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load models
    dino_model = torch.hub.load('facebookresearch/dinov2', args.dino_variant)
    clip_model, _, preprocess = open_clip.create_model_and_transforms(
        args.clip_variant, pretrained=args.clip_pretrained_dataset)

    # Init model wrapper
    dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[args.dino_variant]
    clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1024}[args.clip_variant]
    feature_dim = dino_dim + clip_dim
    # feature_dim = clip_dim
    model = HybridModel(dino_model, clip_model, feature_dim, args.featup,
                        args.mixstyle, args.cosine_branch, args.num_layers).to(device)
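    # Note (assumption): checkpoints written by runs that kept the nn.DataParallel wrapper
    # carry a "module." prefix on every key; if load_state_dict below fails on such a file,
    # strip the prefix first, e.g.
    #   sd = torch.load(args.model_path, map_location=device)
    #   sd = {k[len("module."):] if k.startswith("module.") else k: v for k, v in sd.items()}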
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform)
    # test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform)
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform, is_train=False, jpeg_qf=args.jpeg, gaussian_blur=args.blur)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform, is_train=False, jpeg_qf=args.jpeg, gaussian_blur=args.blur)
    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating", unit="batch"):
            images, labels = images.to(device), labels.float().to(device)
            # noisy_images = images + torch.randn_like(images) * args.noise_std
            outputs = torch.sigmoid(model(images).squeeze())
            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    preds_bin = (np.array(all_preds) > 0.5).astype(int)
    labels_np = np.array(all_labels)
    overall_acc = accuracy_score(labels_np, preds_bin)
    avg_precision = average_precision_score(labels_np, all_preds)
    real_acc = accuracy_score(labels_np[labels_np == 0], preds_bin[labels_np == 0])
    fake_acc = accuracy_score(labels_np[labels_np == 1], preds_bin[labels_np == 1])
    print(f"[Validation] Accuracy: {overall_acc:.4f}, Avg Precision: {avg_precision:.4f}, "
          f"Real Acc: {real_acc:.4f}, Fake Acc: {fake_acc:.4f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Hybrid Model for Fake Image Detection")
    parser.add_argument("--train_fake_dir", type=str, required=True, help="Path to training fake images directory")
    parser.add_argument("--train_real_dir", type=str, required=True, help="Path to training real images directory")
    parser.add_argument("--test_fake_dir", type=str, required=True, help="Path to testing fake images directory")
    parser.add_argument("--test_real_dir", type=str, required=True, help="Path to testing real images directory")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--noise_std", type=float, default=0.1, help="Std of noise added to DINO embeddings")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="GPU device to use (e.g., '0', '0,1')")
    parser.add_argument("--save_every", type=int, default=1, help="Save model every n epochs")
    parser.add_argument("--num_layers", type=int, default=4, help="Number of layers in classifier head (1 to 5)")
    parser.add_argument("--save_model_path", type=str, default='/media/external_16TB_1/amirtaha_amanzadi/DINO', help="Path to save the model")
    parser.add_argument("--dino_variant", type=str, help="DINO variant", default="dinov2_vitl14")
    parser.add_argument("--clip_variant", type=str, help="CLIP variant", default="ViT-L-14")
    parser.add_argument("--clip_pretrained_dataset", type=str, help="CLIP pretrained dataset", default="datacomp_xl_s13b_b90k")
    parser.add_argument("--finetune_dino", action="store_true", help="Whether to finetune the DINO model during training")
    parser.add_argument("--finetune_clip", action="store_true", help="Whether to finetune the CLIP model during training")
    parser.add_argument("--featup", action="store_true", help="Whether to use FeatUp feature upsampling during training")
    parser.add_argument("--mixstyle", action="store_true", help="Whether to apply MixStyle for domain generalization during training")
    parser.add_argument("--cosine_branch", action="store_true", help="Whether to use the cosine branch during training")
    parser.add_argument("--jpeg", type=int, nargs='*', default=None, help="Apply JPEG compression with the given quality factors (e.g., 95 90 75 50)")
    parser.add_argument("--blur", type=float, nargs='*', default=None, help="Apply Gaussian blur with the given sigma values (e.g., 1 2 3)")
    parser.add_argument("--aug_prob", type=float, default=0, help="Probability of applying JPEG/blur augmentations during training")
    parser.add_argument("--eval", action="store_true", help="Run validation only")
    parser.add_argument("--model_path", type=str, help="Path to the saved model for evaluation")
    args = parser.parse_args()
    print(args)

    # Set GPUs
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if args.eval:
        validate_model(args)
    else:
        date_now = datetime.datetime.now()
        args.save_model_path = os.path.join(args.save_model_path, f"Log_{date_now.strftime('%Y-%m-%d_%H-%M-%S')}")
        os.makedirs(args.save_model_path, exist_ok=True)

        # Save args to file in the log directory
        args_file_path = os.path.join(args.save_model_path, "args.txt")
        save_args_to_file(args, args_file_path)

        train_model(args)
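# Example invocations (a sketch; the script filename and all paths are placeholders):
#
#   Training:
#     python train_hybrid.py \
#         --train_fake_dir /data/train/fake --train_real_dir /data/train/real \
#         --test_fake_dir /data/test/fake --test_real_dir /data/test/real \
#         --batch_size 128 --epochs 20 --lr 1e-4 --gpu 0 --num_layers 4
#
#   Validation only (with optional JPEG/blur corruption of the test images):
#     python train_hybrid.py --eval --model_path /path/to/checkpoint.pth \
#         --train_fake_dir /data/train/fake --train_real_dir /data/train/real \
#         --test_fake_dir /data/test/fake --test_real_dir /data/test/real \
#         --jpeg 90 --blur 1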