You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

dataset.py 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import os
  2. import random
  3. from PIL import Image, ImageFilter
  4. from io import BytesIO
  5. import torch
  6. from torch.utils.data import Dataset
  7. import torchvision.transforms as transforms
  8. class ImageDataset(Dataset):
  9. def __init__(self, folder, label, transform, is_train=False, aug_prob=0, jpeg_qf=None, gaussian_blur=None):
  10. self.folder = folder
  11. self.label = label
  12. self.transform = transform
  13. self.is_train = is_train
  14. self.aug_prob = aug_prob
  15. self.jpeg_qf = jpeg_qf
  16. self.gaussian_blur = gaussian_blur
  17. self.images = [
  18. os.path.join(folder, img) for img in os.listdir(folder)
  19. if os.path.isfile(os.path.join(folder, img))
  20. ]
  21. self.jpeg_choices = [90, 75, 50]
  22. self.blur_choices = [1.0, 2.0, 3.0]
  23. def __len__(self):
  24. return len(self.images)
  25. def __getitem__(self, idx):
  26. image_path = self.images[idx]
  27. try:
  28. image = Image.open(image_path).convert("RGB")
  29. # === TRAIN MODE: apply JPEG/blur randomly ===
  30. if self.is_train:
  31. if random.random() < self.aug_prob:
  32. qf = random.choice(self.jpeg_choices)
  33. from io import BytesIO
  34. buffer = BytesIO()
  35. image.save(buffer, format="JPEG", quality=qf)
  36. buffer.seek(0)
  37. image = Image.open(buffer).convert("RGB")
  38. if random.random() < self.aug_prob:
  39. sigma = random.choice(self.blur_choices)
  40. image = image.filter(ImageFilter.GaussianBlur(radius=sigma))
  41. # === EVAL MODE: apply deterministic JPEG/blur if given ===
  42. else:
  43. if self.jpeg_qf:
  44. from io import BytesIO
  45. qf = random.choice(self.jpeg_qf)
  46. buffer = BytesIO()
  47. image.save(buffer, format="JPEG", quality=qf)
  48. buffer.seek(0)
  49. image = Image.open(buffer).convert("RGB")
  50. if self.gaussian_blur:
  51. sigma = random.choice(self.gaussian_blur)
  52. image = image.filter(ImageFilter.GaussianBlur(radius=sigma))
  53. image = self.transform(image)
  54. return image, self.label
  55. except Exception as e:
  56. print(f"Skipping corrupted image {image_path}: {e}")
  57. raise IndexError
  58. # class ImageDataset(Dataset):
  59. # def __init__(self, folder, label, transform, jpeg_qf=None, gaussian_blur=None):
  60. # self.folder = folder
  61. # self.label = label
  62. # self.transform = transform
  63. # self.jpeg_qf = jpeg_qf
  64. # self.gaussian_blur = gaussian_blur
  65. # self.images = [
  66. # os.path.join(folder, img) for img in os.listdir(folder)
  67. # if os.path.isfile(os.path.join(folder, img))
  68. # ]
  69. # def __len__(self):
  70. # return len(self.images)
  71. # def __getitem__(self, idx):
  72. # image_path = self.images[idx]
  73. # try:
  74. # image = Image.open(image_path).convert("RGB")
  75. # # --- JPEG Compression ---
  76. # if self.jpeg_qf:
  77. # buffer = BytesIO()
  78. # qf = random.choice(self.jpeg_qf)
  79. # image.save(buffer, format="JPEG", quality=qf)
  80. # buffer.seek(0)
  81. # image = Image.open(buffer).convert("RGB")
  82. # # --- Gaussian Blur ---
  83. # if self.gaussian_blur:
  84. # sigma = random.choice(self.gaussian_blur)
  85. # image = image.filter(ImageFilter.GaussianBlur(radius=sigma))
  86. # image = self.transform(image)
  87. # return image, self.label
  88. # except Exception as e:
  89. # print(f"Skipping corrupted image {image_path}: {e}")
  90. # raise IndexError