import os
import random
from io import BytesIO

from PIL import Image, ImageFilter
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    """Folder-backed dataset in which every image shares one label.

    Each item is loaded as RGB, optionally degraded with a JPEG
    re-compression and/or a Gaussian blur, passed through ``transform`` and
    returned as ``(transformed_image, label)``.

    Args:
        folder: Directory whose regular files are treated as images
            (sub-directories are ignored, no extension filtering is done).
        label: Label returned unchanged with every item.
        transform: Callable applied to the (possibly degraded) PIL image.
        is_train: Selects random training augmentation (True) or the
            evaluation degradations driven by ``jpeg_qf`` / ``gaussian_blur``
            (False).
        aug_prob: Training only. Probability, checked independently for JPEG
            compression and for blur, that the degradation is applied.
        jpeg_qf: Evaluation only. Non-empty sequence of JPEG quality factors;
            one is drawn with ``random.choice`` per item, so the result is
            deterministic only when the sequence has a single element.
            ``None`` or an empty sequence disables JPEG compression.
        gaussian_blur: Evaluation only. Non-empty sequence of blur radii,
            drawn the same way as ``jpeg_qf``. ``None`` or an empty sequence
            disables blurring.
    """

    def __init__(self, folder, label, transform, is_train=False, aug_prob=0,
                 jpeg_qf=None, gaussian_blur=None):
        super().__init__()
        self.folder = folder
        self.label = label
        self.transform = transform
        self.is_train = is_train
        self.aug_prob = aug_prob
        self.jpeg_qf = jpeg_qf
        self.gaussian_blur = gaussian_blur

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        candidates = (
            os.path.join(folder, name) for name in sorted(os.listdir(folder))
        )
        self.images = [path for path in candidates if os.path.isfile(path)]

        # Values sampled by the random training augmentation.
        self.jpeg_choices = [90, 75, 50]
        self.blur_choices = [1.0, 2.0, 3.0]

    def __len__(self):
        """Return the number of files found in ``folder``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, degrade and transform the image at position ``idx``.

        Returns:
            A ``(transformed_image, label)`` tuple.

        Raises:
            IndexError: If ``idx`` is out of range, or if loading, degrading
                or transforming the image fails (the original error is
                attached as ``__cause__``).
        """
        image_path = self.images[idx]
        try:
            # Close the file handle as soon as the pixels have been copied.
            with Image.open(image_path) as source:
                image = source.convert("RGB")

            if self.is_train:
                image = self._augment_for_training(image)
            else:
                image = self._degrade_for_eval(image)

            return self.transform(image), self.label
        except Exception as e:
            print(f"Skipping corrupted image {image_path}: {e}")
            # NOTE: IndexError is kept for backward compatibility. Be aware
            # that plain ``for item in dataset`` iteration treats IndexError
            # as end-of-sequence and will stop silently at a bad file.
            raise IndexError(f"Failed to load image {image_path}") from e

    def _augment_for_training(self, image):
        """Randomly apply JPEG compression and/or blur with ``aug_prob``."""
        if random.random() < self.aug_prob:
            image = self._jpeg_roundtrip(image, random.choice(self.jpeg_choices))
        if random.random() < self.aug_prob:
            image = self._blur(image, random.choice(self.blur_choices))
        return image

    def _degrade_for_eval(self, image):
        """Apply the JPEG/blur degradations configured for evaluation."""
        if self.jpeg_qf:
            image = self._jpeg_roundtrip(image, random.choice(self.jpeg_qf))
        if self.gaussian_blur:
            image = self._blur(image, random.choice(self.gaussian_blur))
        return image

    @staticmethod
    def _jpeg_roundtrip(image, quality):
        """Encode ``image`` as in-memory JPEG and decode it back to RGB."""
        with BytesIO() as buffer:
            image.save(buffer, format="JPEG", quality=quality)
            buffer.seek(0)
            with Image.open(buffer) as recompressed:
                # convert() loads the pixel data into a new image, so the
                # result stays valid after the buffer is closed.
                return recompressed.convert("RGB")

    @staticmethod
    def _blur(image, sigma):
        """Return ``image`` blurred with a Gaussian kernel of radius ``sigma``."""
        return image.filter(ImageFilter.GaussianBlur(radius=sigma))