import argparse
import clip
import open_clip
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
from sklearn.metrics import accuracy_score, average_precision_score
import numpy as np
import warnings
warnings.filterwarnings('ignore')
from dataset import ImageDataset
from model import HybridModel, HybridModel_FeatUp  # Your two model classes
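
# This script trains a binary fake/real image detector with DistributedDataParallel.
# Depending on --featup it builds either HybridModel (frozen DINOv2 + CLIP backbones)
# or HybridModel_FeatUp (CLIP combined with FeatUp-upsampled features); see model.py.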

# Utility function for calculating metrics
def calculate_metrics(predictions, labels):
    # Average precision is computed on the raw scores; accuracy on the 0.5-thresholded predictions.
    avg_precision = average_precision_score(labels, predictions)
    accuracy = accuracy_score(labels, (predictions > 0.5).astype(int))
    return accuracy, avg_precision
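# Illustrative example: calculate_metrics(np.array([0.9, 0.2, 0.7]), np.array([1, 0, 1])) returns (1.0, 1.0).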

# Training function using DDP
def train_model(args, dino_variant="dinov2_vitl14", clip_variant=("ViT-L-14", "datacomp_xl_s13b_b90k")):
    # Initialize distributed training
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)  # bind this process to its GPU
    device = torch.device("cuda", local_rank)
    print(f"Local Rank: {local_rank}, Device: {device}")

    # Model initialization (two branches based on args.featup)
    if args.featup:
        # For the featup branch, we only use CLIP with FeatUp
        clip_model, _, _ = open_clip.create_model_and_transforms(clip_variant[0], pretrained=clip_variant[1])
        # Freeze CLIP parameters (or finetune if desired)
        for name, param in clip_model.named_parameters():
            param.requires_grad = False
        if args.finetune_clip:
            for name, param in clip_model.named_parameters():
                if 'transformer.resblocks.21' in name:
                    param.requires_grad = True
            print("Unfrozen layers in CLIP model:")
            for name, param in clip_model.named_parameters():
                if param.requires_grad:
                    print(name)
        # In this branch the feature_dim is computed as (384 * 256 * 256) + clip_dim.
        # (Adjust the numbers if needed for your setting.)
        feature_dim = (384 * 256 * 256) + {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1280}[clip_variant[0]]
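        # The (384 * 256 * 256) term above presumably corresponds to a 384-channel feature
        # map upsampled to 256x256 and flattened; the exact layout depends on HybridModel_FeatUp.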
        hybrid_model = HybridModel_FeatUp(clip_model, feature_dim, args.mixstyle)
    else:
        # Load DINOv2 and CLIP models
        dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)
        clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_variant[0], pretrained=clip_variant[1])
        # Freeze CLIP parameters (or finetune if desired)
        for name, param in clip_model.named_parameters():
            param.requires_grad = False
        if args.finetune_clip:
            for name, param in clip_model.named_parameters():
                if 'transformer.resblocks.21' in name:
                    param.requires_grad = True
            print("Unfrozen layers in CLIP model:")
            for name, param in clip_model.named_parameters():
                if param.requires_grad:
                    print(name)
        # Freeze DINO parameters (or finetune if desired)
        for name, param in dino_model.named_parameters():
            param.requires_grad = False
        if args.finetune_dino:
            for name, param in dino_model.named_parameters():
                if 'blocks.23.attn.proj' in name or 'blocks.23.mlp.fc2' in name:
                    param.requires_grad = True
            print("Unfrozen layers in DINOv2 model:")
            for name, param in dino_model.named_parameters():
                if param.requires_grad:
                    print(name)
        dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[dino_variant]
        clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1280}[clip_variant[0]]
        feature_dim = dino_dim + clip_dim
        hybrid_model = HybridModel(dino_model, clip_model, feature_dim, args.mixstyle)
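        # feature_dim = dino_dim + clip_dim assumes HybridModel concatenates the global
        # DINOv2 and CLIP embeddings before its classification head (see model.py).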

    print("Model Initialization done ...")

    # Wrap the model with DDP
    hybrid_model.to(device)
    hybrid_model = nn.parallel.DistributedDataParallel(hybrid_model, device_ids=[local_rank], output_device=local_rank)

    # Print trainable parameters (only printed by rank 0)
    if dist.get_rank() == 0:
        print("Trainable parameters:", sum(p.numel() for p in hybrid_model.parameters() if p.requires_grad))

    # Data Preparation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
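    # Note: these are ImageNet normalization statistics; open_clip's bundled preprocess
    # uses different (OpenAI CLIP) mean/std, so keep whichever statistics the hybrid
    # model was designed for.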

    train_fake_dataset = ImageDataset(args.train_fake_dir, label=1, transform=transform)
    train_real_dataset = ImageDataset(args.train_real_dir, label=0, transform=transform)
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform)

    train_dataset = torch.utils.data.ConcatDataset([train_fake_dataset, train_real_dataset])
    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])

    # Use DistributedSampler for multi-GPU training
    train_sampler = DistributedSampler(train_dataset)
    test_sampler = DistributedSampler(test_dataset, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, sampler=test_sampler, num_workers=8, pin_memory=True)
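    # DistributedSampler gives each rank a disjoint shard of the dataset, so --batch_size
    # is the per-GPU batch size (effective batch size = batch_size * world_size).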

    # Optimizer and loss function
    optimizer = optim.Adam([p for p in hybrid_model.parameters() if p.requires_grad], lr=args.lr)
    criterion = nn.BCEWithLogitsLoss()
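    # Only parameters with requires_grad=True (typically the hybrid head plus any unfrozen
    # backbone layers) are optimized. BCEWithLogitsLoss applies the sigmoid internally,
    # which is numerically more stable than Sigmoid followed by BCELoss.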

    # Training loop
    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)
        hybrid_model.train()
        running_loss = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}", unit="batch", disable=(dist.get_rank() != 0)):
            images, labels = images.to(device), labels.float().to(device)
            noisy_images = images + torch.randn_like(images) * args.noise_std
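            # Each batch is passed as a clean view plus a Gaussian-perturbed view; how the
            # two views are fused is defined inside HybridModel / HybridModel_FeatUp.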

            optimizer.zero_grad()
            outputs = hybrid_model(images, noisy_images).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        if dist.get_rank() == 0:
            print(f"Epoch {epoch + 1}: Loss = {running_loss / len(train_loader):.4f}")

        # Evaluation
        hybrid_model.eval()
        all_preds, all_labels = [], []

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.float().to(device)
                noisy_images = images + torch.randn_like(images) * args.noise_std
                outputs = torch.sigmoid(hybrid_model(images, noisy_images).squeeze())
                all_preds.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Each rank evaluated only its shard of the test set; gather the per-rank
        # predictions so rank 0 can compute metrics over the full set.
        # (DistributedSampler may pad with a few duplicate samples to even out shards.)
        gathered_preds = [None] * dist.get_world_size()
        gathered_labels = [None] * dist.get_world_size()
        dist.all_gather_object(gathered_preds, all_preds)
        dist.all_gather_object(gathered_labels, all_labels)

        if dist.get_rank() == 0:
            all_preds = np.concatenate([np.asarray(p) for p in gathered_preds])
            all_labels = np.concatenate([np.asarray(l) for l in gathered_labels])
            accuracy, avg_precision = calculate_metrics(all_preds, all_labels)
            print(f"Epoch {epoch + 1}: Test Accuracy = {accuracy:.4f}, Average Precision = {avg_precision:.4f}")

            # Save a checkpoint every args.save_every epochs (rank 0 only); save the
            # underlying module so the weights load without the DDP "module." prefix.
            if (epoch + 1) % args.save_every == 0:
                save_path = os.path.join(args.save_model_path, f"{args.save_model_name}_ep{epoch + 1}_acc_{accuracy:.4f}_ap_{avg_precision:.4f}.pth")
                torch.save(hybrid_model.module.state_dict(), save_path)
                print(f"Model saved at {save_path}")

    # Clean up distributed training
    dist.destroy_process_group()


def save_args_to_file(args, save_path):
    """Save the parser arguments to a text file."""
    with open(save_path, "w") as f:
        for arg, value in vars(args).items():
            f.write(f"{arg}: {value}\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Hybrid Model for Fake Image Detection")
    parser.add_argument("--train_fake_dir", type=str, required=True, help="Path to training fake images directory")
    parser.add_argument("--train_real_dir", type=str, required=True, help="Path to training real images directory")
    parser.add_argument("--test_fake_dir", type=str, required=True, help="Path to testing fake images directory")
    parser.add_argument("--test_real_dir", type=str, required=True, help="Path to testing real images directory")
    parser.add_argument("--batch_size", type=int, default=256, help="Per-GPU batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--noise_std", type=float, default=0.1, help="Std of the Gaussian noise added to the second (noisy) image view")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="GPU device to use (e.g., '0', '0,1'); unused under torchrun, which assigns devices via LOCAL_RANK")
    parser.add_argument("--save_every", type=int, default=1, help="Save model every n epochs")
    parser.add_argument("--save_model_path", type=str, default='./saved_models', help="Path to save the model")
    parser.add_argument("--save_model_name", type=str, required=True, help="Base name for saved models")
    parser.add_argument("--finetune_dino", action="store_true", help="Whether to finetune the DINOv2 model during training")
    parser.add_argument("--finetune_clip", action="store_true", help="Whether to finetune the CLIP model during training")
    parser.add_argument("--featup", action="store_true", help="Whether to use FeatUp-style feature upsampling during training")
    parser.add_argument("--mixstyle", action="store_true", help="Whether to apply MixStyle for domain generalization during training")

    args = parser.parse_args()

    date_now = datetime.datetime.now()
    args.save_model_path = os.path.join(args.save_model_path, f"Log_{date_now.strftime('%Y-%m-%d_%H-%M-%S')}")
    os.makedirs(args.save_model_path, exist_ok=True)

    # Save args to file in log directory
    args_file_path = os.path.join(args.save_model_path, "args.txt")
    save_args_to_file(args, args_file_path)

    # (Do not set CUDA_VISIBLE_DEVICES here; torchrun manages devices based on LOCAL_RANK)
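    # Example launch (illustrative script name and paths; assumes 4 GPUs on one node):
    #   torchrun --nproc_per_node=4 train.py \
    #       --train_fake_dir /data/train/fake --train_real_dir /data/train/real \
    #       --test_fake_dir /data/test/fake --test_real_dir /data/test/real \
    #       --save_model_name hybrid_dino_clip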

    train_model(args)