import argparse
import clip
import open_clip
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
from sklearn.metrics import accuracy_score, average_precision_score
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from dataset import ImageDataset
from model import HybridModel, HybridModel_FeatUp  # Your two model classes


# Utility function for calculating metrics
def calculate_metrics(predictions, labels):
    # Accuracy uses thresholded predictions; average precision is computed on the raw scores
    scores = predictions
    predictions = (predictions > 0.5).astype(int)
    accuracy = accuracy_score(labels, predictions)
    avg_precision = average_precision_score(labels, scores)
    return accuracy, avg_precision


# Training function using DDP
def train_model(args, dino_variant="dinov2_vitl14", clip_variant=("ViT-L-14", "datacomp_xl_s13b_b90k")):
    # Initialize distributed training
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)
    print(f"Local Rank: {local_rank}, Device: {device}")

    # Model initialization (two branches based on args.featup)
    if args.featup:
        # For the FeatUp branch, we only use CLIP together with FeatUp features
        clip_model, _, _ = open_clip.create_model_and_transforms(clip_variant[0], pretrained=clip_variant[1])

        # Freeze CLIP parameters (or finetune if desired)
        for name, param in clip_model.named_parameters():
            param.requires_grad = False
        if args.finetune_clip:
            for name, param in clip_model.named_parameters():
                if 'transformer.resblocks.21' in name:
                    param.requires_grad = True
            print("Unfrozen layers in CLIP model:")
            for name, param in clip_model.named_parameters():
                if param.requires_grad:
                    print(name)

        # In this branch the feature_dim is computed as (384 * 256 * 256) + clip_dim.
        # (Adjust the numbers if needed for your setting.)
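        # (Assumption: HybridModel_FeatUp flattens 384-channel FeatUp features over a
        # 256x256 grid and concatenates them with CLIP's global embedding; update this
        # if your FeatUp backbone or output resolution differs.)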
        feature_dim = (384 * 256 * 256) + {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1280}[clip_variant[0]]
        hybrid_model = HybridModel_FeatUp(clip_model, feature_dim, args.mixstyle)
    else:
        # Load DINOv2 and CLIP models
        dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)
        clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_variant[0], pretrained=clip_variant[1])

        # Freeze CLIP parameters (or finetune if desired)
        for name, param in clip_model.named_parameters():
            param.requires_grad = False
        if args.finetune_clip:
            for name, param in clip_model.named_parameters():
                if 'transformer.resblocks.21' in name:
                    param.requires_grad = True
            print("Unfrozen layers in CLIP model:")
            for name, param in clip_model.named_parameters():
                if param.requires_grad:
                    print(name)

        # Freeze DINO parameters (or finetune if desired)
        for name, param in dino_model.named_parameters():
            param.requires_grad = False
        if args.finetune_dino:
            for name, param in dino_model.named_parameters():
                if 'blocks.23.attn.proj' in name or 'blocks.23.mlp.fc2' in name:
                    param.requires_grad = True
            print("Unfrozen layers in DINOv2 model:")
            for name, param in dino_model.named_parameters():
                if param.requires_grad:
                    print(name)

        dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[dino_variant]
        clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1280}[clip_variant[0]]
        feature_dim = dino_dim + clip_dim
        hybrid_model = HybridModel(dino_model, clip_model, feature_dim, args.mixstyle)

    print("Model Initialization done ...")

    # Wrap the model with DDP
    hybrid_model.to(device)
    hybrid_model = nn.parallel.DistributedDataParallel(hybrid_model, device_ids=[local_rank], output_device=local_rank)

    # Print trainable parameters (only printed by rank 0)
    if dist.get_rank() == 0:
        print("Trainable parameters:", sum(p.numel() for p in hybrid_model.parameters() if p.requires_grad))

    # Data Preparation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_fake_dataset = ImageDataset(args.train_fake_dir, label=1, transform=transform)
    train_real_dataset = ImageDataset(args.train_real_dir, label=0, transform=transform)
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform)

    train_dataset = torch.utils.data.ConcatDataset([train_fake_dataset, train_real_dataset])
    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])

    # Use DistributedSampler for multi-GPU training
    train_sampler = DistributedSampler(train_dataset)
    test_sampler = DistributedSampler(test_dataset, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, sampler=test_sampler, num_workers=8, pin_memory=True)

    # Optimizer and loss function
    optimizer = optim.Adam([p for p in hybrid_model.parameters() if p.requires_grad], lr=args.lr)
    criterion = nn.BCEWithLogitsLoss()
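    # The hybrid model receives two views of each batch: the clean images and a copy
    # perturbed with additive Gaussian noise (std = args.noise_std).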
    # Training loop
    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)
        hybrid_model.train()
        running_loss = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}", unit="batch",
                                   disable=(dist.get_rank() != 0)):
            images, labels = images.to(device), labels.float().to(device)
            noisy_images = images + torch.randn_like(images) * args.noise_std

            optimizer.zero_grad()
            outputs = hybrid_model(images, noisy_images).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        if dist.get_rank() == 0:
            print(f"Epoch {epoch + 1}: Loss = {running_loss / len(train_loader):.4f}")

        # Evaluation
        hybrid_model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.float().to(device)
                noisy_images = images + torch.randn_like(images) * args.noise_std
                outputs = torch.sigmoid(hybrid_model(images, noisy_images).squeeze())
                all_preds.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Compute metrics on rank 0 (note: each rank only evaluates its own shard of the test set)
        if dist.get_rank() == 0:
            accuracy, avg_precision = calculate_metrics(np.array(all_preds), np.array(all_labels))
            print(f"Epoch {epoch + 1}: Test Accuracy = {accuracy:.4f}, Average Precision = {avg_precision:.4f}")

            # Save model every args.save_every epochs (rank 0 only)
            if (epoch + 1) % args.save_every == 0:
                save_path = os.path.join(args.save_model_path,
                                         f"{args.save_model_name}_ep{epoch + 1}_acc_{accuracy:.4f}_ap_{avg_precision:.4f}.pth")
                # Save the unwrapped module so the checkpoint loads without the DDP 'module.' prefix
                torch.save(hybrid_model.module.state_dict(), save_path)
                print(f"Model saved at {save_path}")

    # Clean up distributed training
    dist.destroy_process_group()


def save_args_to_file(args, save_path):
    """Save the parser arguments to a text file."""
    with open(save_path, "w") as f:
        for arg, value in vars(args).items():
            f.write(f"{arg}: {value}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Hybrid Model for Fake Image Detection")
    parser.add_argument("--train_fake_dir", type=str, required=True, help="Path to training fake images directory")
    parser.add_argument("--train_real_dir", type=str, required=True, help="Path to training real images directory")
    parser.add_argument("--test_fake_dir", type=str, required=True, help="Path to testing fake images directory")
    parser.add_argument("--test_real_dir", type=str, required=True, help="Path to testing real images directory")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--noise_std", type=float, default=0.1, help="Std of the Gaussian noise added to the noisy input branch")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="GPU device to use (e.g., '0', '0,1'); ignored under torchrun, which sets LOCAL_RANK")
    parser.add_argument("--save_every", type=int, default=1, help="Save model every n epochs")
    parser.add_argument("--save_model_path", type=str, default='./saved_models', help="Path to save the model")
    parser.add_argument("--save_model_name", type=str, required=True, help="Base name for saved models")
    parser.add_argument("--finetune_dino", action="store_true", help="Whether to finetune the DINO model during training")
    parser.add_argument("--finetune_clip", action="store_true", help="Whether to finetune the CLIP model during training")
    parser.add_argument("--featup", action="store_true", help="Whether to use FeatUp feature upsampling during training")
    parser.add_argument("--mixstyle", action="store_true", help="Whether to apply MixStyle for domain generalization during training")

    args = parser.parse_args()

    date_now = datetime.datetime.now()
    args.save_model_path = os.path.join(args.save_model_path,
f"Log_{date_now.strftime('%Y-%m-%d_%H-%M-%S')}") os.makedirs(args.save_model_path, exist_ok=True) # Save args to file in log directory args_file_path = os.path.join(args.save_model_path, "args.txt") save_args_to_file(args, args_file_path) # (Do not set CUDA_VISIBLE_DEVICES here; torchrun manages devices based on LOCAL_RANK) train_model(args)