| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- import os
- import random
- from PIL import Image, ImageFilter
- from io import BytesIO
- import torch
- from torch.utils.data import Dataset
- import torchvision.transforms as transforms
-
-
class ImageDataset(Dataset):
    """Single-label image-folder dataset with optional JPEG / blur degradation.

    Every regular file located directly inside ``folder`` (non-recursive) is
    treated as one sample, and every sample is returned with the same
    ``label``.

    Args:
        folder: Directory whose regular files are the samples.
        label: Value returned as the second element of every item.
        transform: Callable applied to the (possibly degraded) RGB PIL image;
            its result is the first element of every item.
        is_train: If true, degradations are sampled per item from
            ``jpeg_choices`` / ``blur_choices``; ``jpeg_qf`` and
            ``gaussian_blur`` are ignored.
        aug_prob: Train mode only. Probability, tested independently for the
            JPEG step and for the blur step, that the step is applied.
        jpeg_qf: Eval mode only. Sequence of JPEG quality factors; one is
            drawn with ``random.choice`` for each item. Falsy disables the
            step.
        gaussian_blur: Eval mode only. Sequence of Gaussian blur radii; one
            is drawn with ``random.choice`` for each item. Falsy disables the
            step.
    """

    def __init__(self, folder, label, transform, is_train=False, aug_prob=0, jpeg_qf=None, gaussian_blur=None):
        self.folder = folder
        self.label = label
        self.transform = transform
        self.is_train = is_train
        self.aug_prob = aug_prob
        self.jpeg_qf = jpeg_qf
        self.gaussian_blur = gaussian_blur

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        candidates = (os.path.join(folder, name) for name in sorted(os.listdir(folder)))
        self.images = [path for path in candidates if os.path.isfile(path)]

        # Values sampled from when is_train is true.
        self.jpeg_choices = [90, 75, 50]
        self.blur_choices = [1.0, 2.0, 3.0]

    def __len__(self):
        """Return the number of files found in ``folder``."""
        return len(self.images)

    @staticmethod
    def _jpeg_roundtrip(image, quality):
        """Re-encode ``image`` as an in-memory JPEG at ``quality`` and decode it back to RGB."""
        buffer = BytesIO()
        image.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        return Image.open(buffer).convert("RGB")

    @staticmethod
    def _blur(image, radius):
        """Return ``image`` filtered with a Gaussian blur of the given ``radius``."""
        return image.filter(ImageFilter.GaussianBlur(radius=radius))

    def _degrade(self, image):
        """Apply the train-mode or eval-mode JPEG / blur steps to ``image``."""
        if self.is_train:
            # Each step is gated by its own independent draw against aug_prob.
            if random.random() < self.aug_prob:
                image = self._jpeg_roundtrip(image, random.choice(self.jpeg_choices))
            if random.random() < self.aug_prob:
                image = self._blur(image, random.choice(self.blur_choices))
            return image

        # Eval mode: a step always runs when its option is given, but the
        # value is still drawn at random when several are supplied.
        if self.jpeg_qf:
            image = self._jpeg_roundtrip(image, random.choice(self.jpeg_qf))
        if self.gaussian_blur:
            image = self._blur(image, random.choice(self.gaussian_blur))
        return image

    def __getitem__(self, idx):
        """Load, degrade and transform the sample at ``idx``.

        Returns:
            A ``(transformed_image, label)`` tuple.

        Raises:
            IndexError: If ``idx`` is out of range, or if loading, degrading
                or transforming the image fails. In the latter case the
                original exception is attached as ``__cause__``.
        """
        image_path = self.images[idx]
        try:
            # The context manager closes the underlying file once the pixel
            # data has been copied out by convert().
            with Image.open(image_path) as source:
                image = source.convert("RGB")

            image = self._degrade(image)
            return self.transform(image), self.label

        except Exception as e:
            print(f"Skipping corrupted image {image_path}: {e}")
            # Same exception type as before; chain the cause so it is not lost.
            raise IndexError(f"could not load image {image_path}") from e
-
-
-
-
- # class ImageDataset(Dataset):
- # def __init__(self, folder, label, transform, jpeg_qf=None, gaussian_blur=None):
- # self.folder = folder
- # self.label = label
- # self.transform = transform
- # self.jpeg_qf = jpeg_qf
- # self.gaussian_blur = gaussian_blur
- # self.images = [
- # os.path.join(folder, img) for img in os.listdir(folder)
- # if os.path.isfile(os.path.join(folder, img))
- # ]
-
- # def __len__(self):
- # return len(self.images)
-
- # def __getitem__(self, idx):
- # image_path = self.images[idx]
- # try:
- # image = Image.open(image_path).convert("RGB")
-
- # # --- JPEG Compression ---
- # if self.jpeg_qf:
- # buffer = BytesIO()
- # qf = random.choice(self.jpeg_qf)
- # image.save(buffer, format="JPEG", quality=qf)
- # buffer.seek(0)
- # image = Image.open(buffer).convert("RGB")
-
- # # --- Gaussian Blur ---
- # if self.gaussian_blur:
- # sigma = random.choice(self.gaussian_blur)
- # image = image.filter(ImageFilter.GaussianBlur(radius=sigma))
-
- # image = self.transform(image)
- # return image, self.label
- # except Exception as e:
- # print(f"Skipping corrupted image {image_path}: {e}")
- # raise IndexError
|