
train_concat.py

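"""Train and evaluate a hybrid DINOv2 + CLIP model for fake-image detection.

Features from the two (optionally partially finetuned) backbones are
concatenated and fed to a classifier head (see model.HybridModel); images
are loaded through dataset.ImageDataset.
"""
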
import argparse
import open_clip
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import os
from tqdm import tqdm
from sklearn.metrics import accuracy_score, average_precision_score
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from dataset import ImageDataset
from model import HybridModel

# Utility function for computing metrics
def calculate_metrics(predictions, labels):
    # Average precision is computed on the raw scores; accuracy on the
    # 0.5-thresholded binary predictions.
    avg_precision = average_precision_score(labels, predictions)
    predictions = (predictions > 0.5).astype(int)
    accuracy = accuracy_score(labels, predictions)
    return accuracy, avg_precision

# Training function
def train_model(args):  # {ViT-L-14-quickgelu, dfn2b} {ViT-H-14-quickgelu, dfn5b}
    # Set device
    device_ids = [int(gpu) for gpu in args.gpu.split(",")]
    device = torch.device(f"cuda:{device_ids[0]}" if torch.cuda.is_available() else "cpu")
    print(device)
    # Load the CLIP and DINOv2 backbones
    dino_model = torch.hub.load('facebookresearch/dinov2', args.dino_variant)
    # clip_model, _ = clip.load(clip_variant, device)
    clip_model, _, preprocess = open_clip.create_model_and_transforms(args.clip_variant, pretrained=args.clip_pretrained_dataset)

    # Freeze/finetune CLIP
    for name, param in clip_model.named_parameters():
        param.requires_grad = False
    if args.finetune_clip:
        for name, param in clip_model.named_parameters():
            if 'transformer.resblocks.21' in name:  # Example: unfreeze last few blocks
                param.requires_grad = True
        print("Unfrozen layers in CLIP model:")
        for name, param in clip_model.named_parameters():
            if param.requires_grad:
                print(name)
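    # NOTE: the hard-coded layer indices here (resblocks.21 above, blocks.23
    # below) assume 24-block ViT-L backbones; adjust them for other variants.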
    # Freeze/finetune DINO
    for name, param in dino_model.named_parameters():
        param.requires_grad = False
    if args.finetune_dino:
        for name, param in dino_model.named_parameters():
            if 'blocks.23.attn.proj' in name or 'blocks.23.mlp.fc2' in name:
                param.requires_grad = True
        print("Unfrozen layers in DINOv2 model:")
        for name, param in dino_model.named_parameters():
            if param.requires_grad:
                print(name)
    # Model initialization
    dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[args.dino_variant]
    clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1024}[args.clip_variant]
    print("clip_dim:", clip_dim, "dino_dim:", dino_dim)
    feature_dim = dino_dim + clip_dim
    # feature_dim = clip_dim
    hybrid_model = HybridModel(dino_model, clip_model, feature_dim, args.featup, args.mixstyle, args.cosine_branch, args.num_layers).to(device)

    # Wrap model for multi-GPU if applicable
    if len(device_ids) > 1:
        print("Using Parallel NN (multiple GPUs)")
        hybrid_model = nn.DataParallel(hybrid_model, device_ids=device_ids)

    # Print trainable parameters
    print("Trainable parameters:", sum(p.numel() for p in hybrid_model.parameters() if p.requires_grad))
    # Data preparation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),  # Random flip
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color jitter
        transforms.RandomRotation(degrees=10),  # Random rotation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_fake_dataset = ImageDataset(args.train_fake_dir, label=1, transform=transform, is_train=True, aug_prob=args.aug_prob)
    train_real_dataset = ImageDataset(args.train_real_dir, label=0, transform=transform, is_train=True, aug_prob=args.aug_prob)
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform, is_train=True, aug_prob=args.aug_prob)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform, is_train=True, aug_prob=args.aug_prob)
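    # NOTE: the test datasets are also built with is_train=True, so the same
    # random JPEG/blur augmentations are applied when evaluating each epoch.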
    train_dataset = torch.utils.data.ConcatDataset([train_fake_dataset, train_real_dataset])
    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
    # Optimizer and loss function
    optimizer = optim.Adam([p for p in hybrid_model.parameters() if p.requires_grad], lr=args.lr)
    criterion = nn.BCEWithLogitsLoss()
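    # BCEWithLogitsLoss applies the sigmoid internally, so the model outputs
    # raw logits during training; sigmoid is applied explicitly at eval time.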
    # Training loop
    for epoch in range(args.epochs):
        hybrid_model.train()
        running_loss = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}", unit="batch"):
            images, labels = images.to(device), labels.float().to(device)
            # Create noise-perturbed images
            # noisy_images = images + torch.randn_like(images) * args.noise_std
            optimizer.zero_grad()
            # squeeze(1) rather than squeeze(): a batch of size 1 keeps its batch dim
            outputs = hybrid_model(images).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}: Loss = {running_loss / len(train_loader):.4f}")
        # Evaluation
        hybrid_model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.float().to(device)
                # noisy_images = images + torch.randn_like(images) * args.noise_std
                outputs = torch.sigmoid(hybrid_model(images).squeeze(1))
                all_preds.extend(outputs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        accuracy, avg_precision = calculate_metrics(np.array(all_preds), np.array(all_labels))
        print(f"Epoch {epoch + 1}: Test Accuracy = {accuracy:.4f}, Average Precision = {avg_precision:.4f}")
        # Save model checkpoint
        if (epoch + 1) % args.save_every == 0:
            save_path = os.path.join(args.save_model_path, f"ep{epoch + 1}_acc_{accuracy:.4f}_ap_{avg_precision:.4f}.pth")
            torch.save(hybrid_model.state_dict(), save_path)
            print(f"Model saved at {save_path}")
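        # NOTE: checkpoints saved from an nn.DataParallel-wrapped model carry
        # a 'module.' key prefix; validate_model strips it when loading.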

def save_args_to_file(args, save_path):
    """Save the parser arguments to a text file."""
    with open(save_path, "w") as f:
        for arg, value in vars(args).items():
            f.write(f"{arg}: {value}\n")

def validate_model(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load models
    dino_model = torch.hub.load('facebookresearch/dinov2', args.dino_variant)
    clip_model, _, preprocess = open_clip.create_model_and_transforms(args.clip_variant, pretrained=args.clip_pretrained_dataset)

    # Init model wrapper
    dino_dim = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024}[args.dino_variant]
    clip_dim = {"ViT-L-14": 768, "ViT-L-14-quickgelu": 768, "ViT-H-14-quickgelu": 1024}[args.clip_variant]
    feature_dim = dino_dim + clip_dim
    # feature_dim = clip_dim
    model = HybridModel(dino_model, clip_model, feature_dim, args.featup, args.mixstyle, args.cosine_branch, args.num_layers).to(device)
    state_dict = torch.load(args.model_path, map_location=device)
    # Strip the 'module.' prefix that nn.DataParallel adds to checkpoint keys
    state_dict = {(k[7:] if k.startswith("module.") else k): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    test_fake_dataset = ImageDataset(args.test_fake_dir, label=1, transform=transform, is_train=False, jpeg_qf=args.jpeg, gaussian_blur=args.blur)
    test_real_dataset = ImageDataset(args.test_real_dir, label=0, transform=transform, is_train=False, jpeg_qf=args.jpeg, gaussian_blur=args.blur)
    test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating", unit="batch"):
            images, labels = images.to(device), labels.float().to(device)
            # noisy_images = images + torch.randn_like(images) * args.noise_std
            outputs = torch.sigmoid(model(images).squeeze(1))
            all_preds.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    preds_bin = (np.array(all_preds) > 0.5).astype(int)
    labels_np = np.array(all_labels)
    overall_acc = accuracy_score(labels_np, preds_bin)
    avg_precision = average_precision_score(labels_np, all_preds)
    real_acc = accuracy_score(labels_np[labels_np == 0], preds_bin[labels_np == 0])
    fake_acc = accuracy_score(labels_np[labels_np == 1], preds_bin[labels_np == 1])
    print(f"[Validation] Accuracy: {overall_acc:.4f}, Avg Precision: {avg_precision:.4f}, Real Acc: {real_acc:.4f}, Fake Acc: {fake_acc:.4f}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Hybrid Model for Fake Image Detection")
    parser.add_argument("--train_fake_dir", type=str, required=True, help="Path to training fake images directory")
    parser.add_argument("--train_real_dir", type=str, required=True, help="Path to training real images directory")
    parser.add_argument("--test_fake_dir", type=str, required=True, help="Path to testing fake images directory")
    parser.add_argument("--test_real_dir", type=str, required=True, help="Path to testing real images directory")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--noise_std", type=float, default=0.1, help="Std of Gaussian noise for perturbed images (used only by the commented-out noisy_images lines)")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="GPU device(s) to use (e.g., '0', '0,1')")
    parser.add_argument("--save_every", type=int, default=1, help="Save model every n epochs")
    parser.add_argument("--num_layers", type=int, default=4, help="Number of layers in the classifier head (1 to 5)")
    parser.add_argument("--save_model_path", type=str, default='/media/external_16TB_1/amirtaha_amanzadi/DINO', help="Path to save the model")
    parser.add_argument("--dino_variant", type=str, help="DINO variant", default="dinov2_vitl14")
    parser.add_argument("--clip_variant", type=str, help="CLIP variant", default="ViT-L-14")
    parser.add_argument("--clip_pretrained_dataset", type=str, help="CLIP pretrained dataset", default="datacomp_xl_s13b_b90k")
    parser.add_argument("--finetune_dino", action="store_true", help="Whether to finetune the DINO model during training")
    parser.add_argument("--finetune_clip", action="store_true", help="Whether to finetune the CLIP model during training")
    parser.add_argument("--featup", action="store_true", help="Whether to use FeatUp feature upsampling during training")
    parser.add_argument("--mixstyle", action="store_true", help="Whether to apply MixStyle for domain generalization during training")
    parser.add_argument("--cosine_branch", action="store_true", help="Whether to use the cosine branch during training")
    parser.add_argument("--jpeg", type=int, nargs='*', default=None, help="Apply JPEG compression with the given quality factors, e.g. 95 90 75 50")
    parser.add_argument("--blur", type=float, nargs='*', default=None, help="Apply Gaussian blur with the given sigma values, e.g. 1 2 3")
    parser.add_argument("--aug_prob", type=float, default=0, help="Probability of applying JPEG/blur augmentations during training")
    parser.add_argument("--eval", action="store_true", help="Run validation only")
    parser.add_argument("--model_path", type=str, help="Path to the saved model checkpoint for evaluation")
    args = parser.parse_args()
    print(args)
    # Set GPUs
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if args.eval:
        validate_model(args)
    else:
        # Create a timestamped log directory for this run
        args.save_model_path = os.path.join(args.save_model_path, f"Log_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")
        os.makedirs(args.save_model_path, exist_ok=True)
        # Save args to file in log directory
        args_file_path = os.path.join(args.save_model_path, "args.txt")
        save_args_to_file(args, args_file_path)
        train_model(args)
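
# Example usage (directory paths below are placeholders):
#
#   Train:
#     python train_concat.py --train_fake_dir data/train/fake --train_real_dir data/train/real \
#         --test_fake_dir data/test/fake --test_real_dir data/test/real --gpu 0,1
#
#   Evaluate a saved checkpoint (the four directory args are still required):
#     python train_concat.py --eval --model_path checkpoints/model.pth \
#         --train_fake_dir data/train/fake --train_real_dir data/train/real \
#         --test_fake_dir data/test/fake --test_real_dir data/test/real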