
train_concat_ddp.py 10KB

import argparse
import clip
import open_clip
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
from sklearn.metrics import accuracy_score, average_precision_score
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from dataset import ImageDataset
from model import HybridModel, HybridModel_FeatUp  # Your two model classes

# Utility function for calculating metrics
def calculate_metrics(predictions, labels):
    # Average precision must be computed on the raw scores; thresholding first
    # would make it degenerate, so only the accuracy uses binarized predictions
    avg_precision = average_precision_score(labels, predictions)
    predictions = (predictions > 0.5).astype(int)
    accuracy = accuracy_score(labels, predictions)
    return accuracy, avg_precision
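# Example usage with hypothetical toy values (not from this repo): for
# predictions = np.array([0.9, 0.2, 0.7]) and labels = np.array([1, 0, 1]),
# calculate_metrics returns (1.0, 1.0): the scores rank both positives above
# the negative, and every thresholded prediction matches its label.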
# Training function using DDP
def train_model(args, dino_variant="dinov2_vitl14", clip_variant=("ViT-L-14", "datacomp_xl_s13b_b90k")):
    # Initialize distributed training
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)
    print(f"Local Rank: {local_rank}, Device: {device}")

    # Model initialization (two branches based on args.featup)
    if args.featup:
        # For the featup branch, we only use CLIP with FeatUp
        clip_model, _, _ = open_clip.create_model_and_transforms(clip_variant[0], pretrained=clip_variant[1])
        # Freeze CLIP parameters (or finetune if desired)
        for name, param in clip_model.named_parameters():
            param.requires_grad = False
        if args.finetune_clip:
            for name, param in clip_model.named_parameters():
                if 'transformer.resblocks.21' in name:
                    param.requires_grad = True
            print("Unfrozen layers in CLIP model:")
            for name, param in clip_model.named_parameters():
                if param.requires_grad:
                    print(name)
        # In this branch the feature_dim is computed as (384 * 256 * 256) + clip_dim.
        # (Adjust the numbers if needed for your setting.)
        feature_dim = (384 * 256 * 256) + {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1280}[clip_variant[0]]
        hybrid_model = HybridModel_FeatUp(clip_model, feature_dim, args.mixstyle)
    else:
        # Load DINOv2 and CLIP models
        dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)
        clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_variant[0], pretrained=clip_variant[1])
        # Freeze CLIP parameters (or finetune if desired)
        for name, param in clip_model.named_parameters():
            param.requires_grad = False
        if args.finetune_clip:
            for name, param in clip_model.named_parameters():
                if 'transformer.resblocks.21' in name:
                    param.requires_grad = True
            print("Unfrozen layers in CLIP model:")
            for name, param in clip_model.named_parameters():
                if param.requires_grad:
                    print(name)
        # Freeze DINO parameters (or finetune if desired)
        for name, param in dino_model.named_parameters():
            param.requires_grad = False
        if args.finetune_dino:
            for name, param in dino_model.named_parameters():
                if 'blocks.23.attn.proj' in name or 'blocks.23.mlp.fc2' in name:
                    param.requires_grad = True
            print("Unfrozen layers in DINOv2 model:")
            for name, param in dino_model.named_parameters():
                if param.requires_grad:
                    print(name)
        dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[dino_variant]
        clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1280}[clip_variant[0]]
        feature_dim = dino_dim + clip_dim
        hybrid_model = HybridModel(dino_model, clip_model, feature_dim, args.mixstyle)
    print("Model Initialization done ...")

    # Wrap the model with DDP
    hybrid_model.to(device)
    hybrid_model = nn.parallel.DistributedDataParallel(hybrid_model, device_ids=[local_rank], output_device=local_rank)

    # Print trainable parameters (only printed by rank 0)
    if dist.get_rank() == 0:
        print("Trainable parameters:", sum(p.numel() for p in hybrid_model.parameters() if p.requires_grad))

    # Data preparation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_fake_dataset = ImageDataset(args.train_fake_dir, label=1, transform=transform)
    train_real_dataset = ImageDataset(args.train_real_dir, label=0, transform=transform)
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform)
    train_dataset = torch.utils.data.ConcatDataset([train_fake_dataset, train_real_dataset])
    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])

    # Use DistributedSampler for multi-GPU training
    train_sampler = DistributedSampler(train_dataset)
    test_sampler = DistributedSampler(test_dataset, shuffle=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, sampler=test_sampler, num_workers=8, pin_memory=True)

    # Optimizer and loss function
    optimizer = optim.Adam([p for p in hybrid_model.parameters() if p.requires_grad], lr=args.lr)
    criterion = nn.BCEWithLogitsLoss()
    # Training loop
    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)
        hybrid_model.train()
        running_loss = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}", unit="batch", disable=(dist.get_rank() != 0)):
            images, labels = images.to(device), labels.float().to(device)
            noisy_images = images + torch.randn_like(images) * args.noise_std
            optimizer.zero_grad()
            # squeeze(-1) keeps the batch dimension intact even when the last batch holds a single sample
            outputs = hybrid_model(images, noisy_images).squeeze(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        if dist.get_rank() == 0:
            print(f"Epoch {epoch + 1}: Loss = {running_loss / len(train_loader):.4f}")
        # Evaluation
        hybrid_model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.float().to(device)
                noisy_images = images + torch.randn_like(images) * args.noise_std
                outputs = torch.sigmoid(hybrid_model(images, noisy_images).squeeze(-1))
                all_preds.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Compute and report metrics on rank 0 only (predictions are not gathered
        # across ranks, so these metrics cover rank 0's shard of the test set)
        if dist.get_rank() == 0:
            accuracy, avg_precision = calculate_metrics(np.array(all_preds), np.array(all_labels))
            print(f"Epoch {epoch + 1}: Test Accuracy = {accuracy:.4f}, Average Precision = {avg_precision:.4f}")

            # Save model every args.save_every epochs (rank 0 only)
            if (epoch + 1) % args.save_every == 0:
                save_path = os.path.join(args.save_model_path, f"{args.save_model_name}_ep{epoch + 1}_acc_{accuracy:.4f}_ap_{avg_precision:.4f}.pth")
                torch.save(hybrid_model.state_dict(), save_path)
                print(f"Model saved at {save_path}")

    # Clean up distributed training
    dist.destroy_process_group()
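# Note that the evaluation above scores only each rank's own DistributedSampler
# shard, so rank 0's metrics cover roughly 1/world_size of the test set. Below
# is a minimal sketch (an addition, not called anywhere in this script) of how
# per-rank prediction lists could be pooled before computing metrics; it assumes
# the process group is still initialized:
def gather_across_ranks(local_values):
    """Return the concatenation of every rank's list, identical on all ranks."""
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, list(local_values))
    return [v for chunk in gathered for v in chunk]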
def save_args_to_file(args, save_path):
    """Save the parser arguments to a text file."""
    with open(save_path, "w") as f:
        for arg, value in vars(args).items():
            f.write(f"{arg}: {value}\n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Hybrid Model for Fake Image Detection")
    parser.add_argument("--train_fake_dir", type=str, required=True, help="Path to training fake images directory")
    parser.add_argument("--train_real_dir", type=str, required=True, help="Path to training real images directory")
    parser.add_argument("--test_fake_dir", type=str, required=True, help="Path to testing fake images directory")
    parser.add_argument("--test_real_dir", type=str, required=True, help="Path to testing real images directory")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--noise_std", type=float, default=0.1, help="Std of the Gaussian noise added to the second (noisy) image input")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="GPU device to use (e.g., '0', '0,1'); unused under torchrun")
    parser.add_argument("--save_every", type=int, default=1, help="Save model every n epochs")
    parser.add_argument("--save_model_path", type=str, default='./saved_models', help="Path to save the model")
    parser.add_argument("--save_model_name", type=str, required=True, help="Base name for saved models")
    parser.add_argument("--finetune_dino", action="store_true", help="Whether to finetune the DINO model during training")
    parser.add_argument("--finetune_clip", action="store_true", help="Whether to finetune the CLIP model during training")
    parser.add_argument("--featup", action="store_true", help="Whether to use FeatUp feature upsampling during training")
    parser.add_argument("--mixstyle", action="store_true", help="Whether to apply MixStyle for domain generalization during training")
    args = parser.parse_args()

    date_now = datetime.datetime.now()
    args.save_model_path = os.path.join(args.save_model_path, f"Log_{date_now.strftime('%Y-%m-%d_%H-%M-%S')}")
    os.makedirs(args.save_model_path, exist_ok=True)

    # Save args to file in the log directory
    args_file_path = os.path.join(args.save_model_path, "args.txt")
    save_args_to_file(args, args_file_path)

    # (Do not set CUDA_VISIBLE_DEVICES here; torchrun manages devices based on LOCAL_RANK)
    train_model(args)
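# Example launch (assumptions: a single node with 4 GPUs, placeholder data paths):
#   torchrun --nproc_per_node=4 train_concat_ddp.py \
#       --train_fake_dir data/train/fake --train_real_dir data/train/real \
#       --test_fake_dir data/test/fake --test_real_dir data/test/real \
#       --save_model_name hybrid_concat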