
tsne.py

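"""t-SNE visualization of CLIP and DINOv2 image embeddings.

Extracts embeddings for folders of real and AI-generated ("fake") images
with CLIP ViT-L/14 and DINOv2 ViT-L/14 (optionally fine-tuned), projects
them to 2D or 3D with t-SNE, and saves matplotlib PNGs or interactive
Plotly HTML scatter plots. Also includes RIGID-style original-vs-noisy
cosine-similarity utilities.
"""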
import os
import torch
import clip
import pandas as pd
import plotly.express as px
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')
# Ensure models are in evaluation mode and use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Load the DINOv2 and CLIP backbones
dino_variant = "dinov2_vitl14"  # change to the desired DINOv2 variant
dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)
dino_model.eval().to(device)
clip_model, _ = clip.load("ViT-L/14", device=device, jit=False)
clip_model.eval().to(device)
# Define preprocessing transformations (shared by both models)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
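# Note: both backbones are fed this ImageNet-normalized 224x224 transform;
# the CLIP-specific preprocess returned by clip.load is discarded above.
# CLIP was trained with its own normalization statistics, so sharing one
# transform is a simplification rather than the canonical CLIP pipeline.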
def load_finetuned_dino(model_variant, state_dict_path):
    # Load the base DINOv2 model, then overlay the fine-tuned weights.
    # strict=False skips mismatched keys (e.g. extra heads in the
    # checkpoint) instead of raising.
    model = torch.hub.load('facebookresearch/dinov2', model_variant)
    model.eval().to(device)
    state_dict = torch.load(state_dict_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    print(f"Loaded fine-tuned weights from {state_dict_path}")
    return model
# Add Gaussian noise to an image tensor
def add_noise(image, std_dev=0.1):
    noise = torch.randn_like(image) * std_dev
    return image + noise
# Extract embeddings using the DINOv2 model
def extract_dino_embeddings(image_path, model, add_noise_flag=False, noise_std=0.1):
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        if add_noise_flag:
            noisy_image_tensor = add_noise(image_tensor, std_dev=noise_std)
            embedding = model(noisy_image_tensor).squeeze()
        else:
            embedding = model(image_tensor).squeeze()
    return embedding.cpu().numpy()
# Extract embeddings using the CLIP model
def extract_clip_embeddings(image_path, model):
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode_image(image_tensor).squeeze()
    return embedding.cpu().numpy()
# Extract embeddings for all images in a folder
def extract_all_embeddings(folder_path, model, method, add_noise_flag=False, noise_std=0.1):
    embeddings = []
    for image_name in tqdm(os.listdir(folder_path), desc=f"Processing {folder_path}"):
        image_path = os.path.join(folder_path, image_name)
        if os.path.isfile(image_path):
            if method == "dino":
                embedding = extract_dino_embeddings(image_path, model, add_noise_flag, noise_std)
            elif method == "clip":
                embedding = extract_clip_embeddings(image_path, model)
            else:
                raise ValueError(f"Unknown method: {method}")
            embeddings.append(embedding)
    return np.array(embeddings)
# Visualize embeddings in 2D using t-SNE
def visualize_embeddings_2d(real_embeddings, fake_embeddings, title, out_path, perplexity=40):
    all_embeddings = np.vstack([real_embeddings, fake_embeddings])
    labels = ["Real"] * len(real_embeddings) + ["Fake"] * len(fake_embeddings)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    tsne_results = tsne.fit_transform(all_embeddings)
    plt.figure(figsize=(10, 8))
    for label, color in zip(["Real", "Fake"], ["blue", "red"]):
        indices = [i for i, l in enumerate(labels) if l == label]
        plt.scatter(tsne_results[indices, 0], tsne_results[indices, 1], label=label, alpha=0.7, s=50, c=color)
    plt.title(title)
    plt.legend()
    plt.savefig(out_path)
    plt.close()
# Visualize embeddings in 3D using t-SNE and an interactive Plotly plot
def visualize_embeddings_3d(real_embeddings, fake_embeddings, title, out_path, perplexity=40):
    all_embeddings = np.vstack([real_embeddings, fake_embeddings])
    labels = ["Real"] * len(real_embeddings) + ["Fake"] * len(fake_embeddings)
    tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity)
    tsne_results = tsne.fit_transform(all_embeddings)
    # Create a DataFrame for Plotly
    tsne_df = pd.DataFrame(tsne_results, columns=['x', 'y', 'z'])
    tsne_df['label'] = labels
    # Create the 3D scatter plot
    fig = px.scatter_3d(tsne_df, x='x', y='y', z='z', color='label', title=title,
                        color_discrete_map={"Real": "blue", "Fake": "red"},
                        labels={'label': 'Image Type'})
    fig.update_traces(marker=dict(size=5, opacity=0.8), selector=dict(mode='markers'))
    # Save the figure as an interactive HTML file
    fig.write_html(out_path)
    print(f"3D plot saved as {out_path}")
    # Clear the traces to release memory (Plotly's data setter rejects None)
    fig.data = []
def visualize_embeddings_4dirs_3d(real_train_folder, fake_train_folder, real_test_folder, fake_test_folder, title, out_path, method, perplexity=50, finetuned_dino_path=None):
    # Use fine-tuned DINOv2 weights if a path is provided. dino_variant is
    # read from module scope; reassigning it inside this function would make
    # it local and hit an UnboundLocalError in the if-branch.
    if finetuned_dino_path:
        print("Using fine-tuned DINO model.")
        dino_model = load_finetuned_dino(dino_variant, finetuned_dino_path)
    else:
        print("Using original DINO model.")
        dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)
    dino_model.eval().to(device)
    if 'clip' in method:
        # Extract CLIP embeddings for all four folders
        real_train_embeddings = extract_all_embeddings(real_train_folder, clip_model, method="clip")
        fake_train_embeddings = extract_all_embeddings(fake_train_folder, clip_model, method="clip")
        real_test_embeddings = extract_all_embeddings(real_test_folder, clip_model, method="clip")
        fake_test_embeddings = extract_all_embeddings(fake_test_folder, clip_model, method="clip")
    if 'dino' in method:
        # Extract DINOv2 embeddings for all four folders
        real_train_embeddings = extract_all_embeddings(real_train_folder, dino_model, method="dino")
        fake_train_embeddings = extract_all_embeddings(fake_train_folder, dino_model, method="dino")
        real_test_embeddings = extract_all_embeddings(real_test_folder, dino_model, method="dino")
        fake_test_embeddings = extract_all_embeddings(fake_test_folder, dino_model, method="dino")
    if 'concat' in method:
        # CLIP
        clip_real_train_embeddings = extract_all_embeddings(real_train_folder, clip_model, method="clip")
        clip_fake_train_embeddings = extract_all_embeddings(fake_train_folder, clip_model, method="clip")
        clip_real_test_embeddings = extract_all_embeddings(real_test_folder, clip_model, method="clip")
        clip_fake_test_embeddings = extract_all_embeddings(fake_test_folder, clip_model, method="clip")
        # DINOv2
        dino_real_train_embeddings = extract_all_embeddings(real_train_folder, dino_model, method="dino")
        dino_fake_train_embeddings = extract_all_embeddings(fake_train_folder, dino_model, method="dino")
        dino_real_test_embeddings = extract_all_embeddings(real_test_folder, dino_model, method="dino")
        dino_fake_test_embeddings = extract_all_embeddings(fake_test_folder, dino_model, method="dino")
        real_train_embeddings = dino_clip_concat(clip_real_train_embeddings, dino_real_train_embeddings)
        fake_train_embeddings = dino_clip_concat(clip_fake_train_embeddings, dino_fake_train_embeddings)
        real_test_embeddings = dino_clip_concat(clip_real_test_embeddings, dino_real_test_embeddings)
        fake_test_embeddings = dino_clip_concat(clip_fake_test_embeddings, dino_fake_test_embeddings)
    # Combine embeddings and create labels
    all_embeddings = np.vstack([
        real_train_embeddings,
        fake_train_embeddings,
        real_test_embeddings,
        fake_test_embeddings
    ])
    labels = (["Real Train"] * len(real_train_embeddings) +
              ["Fake Train"] * len(fake_train_embeddings) +
              ["Real Test"] * len(real_test_embeddings) +
              ["Fake Test"] * len(fake_test_embeddings))
    # Perform t-SNE on the combined embeddings
    tsne = TSNE(n_components=3, random_state=42, perplexity=perplexity)
    tsne_results = tsne.fit_transform(all_embeddings)
    # Create a DataFrame for Plotly
    tsne_df = pd.DataFrame(tsne_results, columns=['x', 'y', 'z'])
    tsne_df['label'] = labels
    # Create the 3D scatter plot
    fig = px.scatter_3d(tsne_df, x='x', y='y', z='z', color='label', title=title,
                        color_discrete_map={
                            "Real Train": "blue",
                            "Fake Train": "red",
                            "Real Test": "green",
                            "Fake Test": "orange"
                        },
                        labels={'label': 'Image Type'})
    fig.update_traces(marker=dict(size=5, opacity=0.8), selector=dict(mode='markers'))
    # Save the figure as an interactive HTML file
    fig.write_html(out_path)
    print(f"3D plot saved as {out_path}")
    # Clear the traces to release memory (Plotly's data setter rejects None)
    fig.data = []
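# RIGID-style robustness signal: embeddings of real images tend to change
# less under small Gaussian perturbations than embeddings of generated
# images, so the cosine similarity between an image's original and noisy
# DINOv2 embeddings can act as a training-free real/fake score.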
# Compute cosine similarities between original and noisy embeddings
def compute_cosine_similarities(folder_path, model, noise_std=0.1):
    similarities = []
    for image_name in tqdm(os.listdir(folder_path), desc=f"Processing {folder_path} for Cosine Similarities"):
        image_path = os.path.join(folder_path, image_name)
        if os.path.isfile(image_path):
            original_embedding = extract_dino_embeddings(image_path, model, add_noise_flag=False, noise_std=noise_std)
            noisy_embedding = extract_dino_embeddings(image_path, model, add_noise_flag=True, noise_std=noise_std)
            similarity = np.dot(original_embedding, noisy_embedding) / (
                np.linalg.norm(original_embedding) * np.linalg.norm(noisy_embedding)
            )
            similarities.append(similarity)
    return np.array(similarities)
# Visualize cosine similarities on a 1D t-SNE axis
def visualize_cosine_similarities(real_similarities, fake_similarities, output_path, title, perplexity=40):
    all_similarities = np.hstack([real_similarities, fake_similarities])
    labels = ["Real"] * len(real_similarities) + ["Fake"] * len(fake_similarities)
    tsne = TSNE(n_components=1, perplexity=perplexity, random_state=42)
    tsne_results = tsne.fit_transform(all_similarities.reshape(-1, 1))
    plt.figure(figsize=(10, 8))
    for label, color in zip(["Real", "Fake"], ["blue", "red"]):
        indices = [i for i, l in enumerate(labels) if l == label]
        plt.scatter(tsne_results[indices, 0], np.zeros_like(tsne_results[indices, 0]), label=label, alpha=0.7, s=50, c=color)
    plt.title(title)
    plt.legend()
    plt.savefig(output_path)
    plt.close()
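# Align embedding widths before combining: with the ViT-L/14 backbones used
# here, CLIP produces 768-d vectors and DINOv2 produces 1024-d vectors, so
# the wider set is PCA-projected down to match the narrower one.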
def align_embeddings(clip_embeddings, dino_embeddings):
    if clip_embeddings.shape[1] < dino_embeddings.shape[1]:
        pca = PCA(n_components=clip_embeddings.shape[1])
        dino_embeddings_aligned = pca.fit_transform(dino_embeddings)
        return clip_embeddings, dino_embeddings_aligned
    elif clip_embeddings.shape[1] > dino_embeddings.shape[1]:
        pca = PCA(n_components=dino_embeddings.shape[1])
        clip_embeddings_aligned = pca.fit_transform(clip_embeddings)
        return clip_embeddings_aligned, dino_embeddings
    else:
        return clip_embeddings, dino_embeddings
# Combine CLIP and DINOv2 embeddings. Despite the name, this is a weighted
# element-wise sum of the PCA-aligned embeddings, not a literal concatenation.
def dino_clip_concat(clip_embeddings, dino_embeddings, clip_weight=0.5, dino_weight=0.5):
    clip_embeddings, dino_embeddings = align_embeddings(clip_embeddings, dino_embeddings)
    return clip_weight * clip_embeddings + dino_weight * dino_embeddings
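# A literal concatenation variant (a sketch, not used anywhere below) would
# keep both feature sets intact at the cost of a wider vector:
#   combined = np.hstack([clip_embeddings, dino_embeddings])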
def tsne_visualize(real_folder, fake_folder, output_dir, finetuned_dino_path=None, dim=2):
    os.makedirs(output_dir, exist_ok=True)
    # Use fine-tuned DINOv2 weights if a path is provided (dino_variant is
    # read from module scope; reassigning it here would make it local and
    # break the if-branch with an UnboundLocalError)
    if finetuned_dino_path:
        print("Using fine-tuned DINO model.")
        dino_model = load_finetuned_dino(dino_variant, finetuned_dino_path)
    else:
        print("Using original DINO model.")
        dino_model = torch.hub.load('facebookresearch/dinov2', dino_variant)
    dino_model.eval().to(device)
    if dim == 2:
        visualize_embeddings = visualize_embeddings_2d
        file_type = 'png'
    elif dim == 3:
        visualize_embeddings = visualize_embeddings_3d
        file_type = 'html'
    else:
        raise ValueError(f"dim must be 2 or 3, got {dim}")
    # CLIP ViT-L/14 embeddings
    real_clip_embeddings = extract_all_embeddings(real_folder, clip_model, method="clip")
    fake_clip_embeddings = extract_all_embeddings(fake_folder, clip_model, method="clip")
    visualize_embeddings(real_clip_embeddings, fake_clip_embeddings, "CLIP-VITL14 Embeddings", os.path.join(output_dir, f"{dim}d_clip_embeddings.{file_type}"))
    # DINOv2 ViT-L/14 embeddings
    real_dino_embeddings = extract_all_embeddings(real_folder, dino_model, method="dino")
    fake_dino_embeddings = extract_all_embeddings(fake_folder, dino_model, method="dino")
    visualize_embeddings(real_dino_embeddings, fake_dino_embeddings, "DINOv2-VITL14 Embeddings", os.path.join(output_dir, f"{dim}d_dino_embeddings.{file_type}"))
    # Weighted combination of CLIP and DINOv2 embeddings
    real_combined_embeddings = dino_clip_concat(real_clip_embeddings, real_dino_embeddings)
    fake_combined_embeddings = dino_clip_concat(fake_clip_embeddings, fake_dino_embeddings)
    visualize_embeddings(real_combined_embeddings, fake_combined_embeddings, "CLIP+DINO Combined Embeddings", os.path.join(output_dir, f"{dim}d_clip_dino_embeddings.{file_type}"))
    # # RIGID: DINOv2 original vs noisy embeddings (real)
    # real_dino_noisy_embeddings = extract_all_embeddings(real_folder, dino_model, method="dino", add_noise_flag=True)
    # visualize_embeddings(real_dino_embeddings, real_dino_noisy_embeddings, "Real: DINOv2 Original vs Noisy", os.path.join(output_dir, f"real_dino_noisy.{file_type}"))
    # # RIGID: DINOv2 original vs noisy embeddings (fake)
    # fake_dino_noisy_embeddings = extract_all_embeddings(fake_folder, dino_model, method="dino", add_noise_flag=True)
    # visualize_embeddings(fake_dino_embeddings, fake_dino_noisy_embeddings, "Fake: DINOv2 Original vs Noisy", os.path.join(output_dir, f"fake_dino_noisy.{file_type}"))
    # # RIGID: Cosine similarities
    # real_cosine_similarities = compute_cosine_similarities(real_folder, dino_model)
    # fake_cosine_similarities = compute_cosine_similarities(fake_folder, dino_model)
    # visualize_cosine_similarities(real_cosine_similarities, fake_cosine_similarities, os.path.join(output_dir, "cosine_similarities.png"), "Cosine Similarities of Real vs Fake")
if __name__ == "__main__":
    # fake_test_folder = "/media/external_16TB_1/amirtaha_amanzadi/datasets/GenImage-tiny-all/1_fake"
    # real_test_folder = "/media/external_16TB_1/amirtaha_amanzadi/datasets/GenImage-tiny-all/0_real"
    # fake_folder = "../../datasets/GenImage-tiny-all/1_fake"
    # real_folder = "../../datasets/GenImage-tiny-all/0_real"
    # fake_folder = "../../datasets/ArtiFact_test_small/1_fake"
    # real_folder = "../../datasets/ArtiFact_test_small/0_real"
    # fake_train_folder = "/media/external_16TB_1/amirtaha_amanzadi/datasets/IMAGINET_train_all/1_fake"
    # real_train_folder = "/media/external_16TB_1/amirtaha_amanzadi/datasets/IMAGINET_train_all/0_real"
    fake_folder = "/media/external_16TB_1/amirtaha_amanzadi/datasets/GenImage-tiny-all/imagenet_all/val/1_fake"
    real_folder = "/media/external_16TB_1/amirtaha_amanzadi/datasets/GenImage-tiny-all/imagenet_all/val/0_real"
    output_dir = "./TSNE/GenImage-tiny"  # Replace with the desired output directory for plots
    # tsne_visualize(real_folder, fake_folder, output_dir, finetuned_dino_path='./saved_models/dino_ep_17_acc_0.5598_ap_0.5360.pth')
    # tsne_visualize(real_folder, fake_folder, output_dir, dim=2)
    tsne_visualize(real_folder, fake_folder, output_dir, dim=3)
    # visualize_embeddings_4dirs_3d(
    #     real_train_folder, fake_train_folder, real_test_folder, fake_test_folder,
    #     title="3D t-SNE for Real/Fake Train and Test Data",
    #     out_path=os.path.join(output_dir, "3d_clip_ft.html"),
    #     method='clip'
    # )
    # visualize_embeddings_4dirs_3d(
    #     real_train_folder, fake_train_folder, real_test_folder, fake_test_folder,
    #     title="3D t-SNE for Real/Fake Train and Test Data",
    #     out_path=os.path.join(output_dir, "3d_dino_ft.html"),
    #     method='dino'
    # )
    # visualize_embeddings_4dirs_3d(
    #     real_train_folder, fake_train_folder, real_test_folder, fake_test_folder,
    #     title="3D t-SNE for Real/Fake Train and Test Data",
    #     out_path=os.path.join(output_dir, "3d_clip_dino_ft.html"),
    #     method='concat'
    # )