import os

from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms


class test_dataset:
    """Sequential loader for evaluation image / ground-truth pairs.

    Images and ground-truth masks are listed from two directories, sorted by
    path, and handed out one pair per ``load_data`` call.

    NOTE(review): pairing is purely positional after sorting, so this assumes
    both directories contain matching file stems in the same sorted order.
    No mismatch check is performed here; confirm the datasets guarantee it.
    """

    def __init__(self, image_root, gt_root, testsize):
        """Index the image and ground-truth files.

        Args:
            image_root: directory holding input images (``.jpg`` / ``.png``).
            gt_root: directory holding ground-truth masks (``.tif`` / ``.png``).
            testsize: side length images are resized to (square output).
        """
        self.testsize = testsize
        # os.path.join works whether or not the roots end with a separator;
        # plain string concatenation silently built bad paths without one.
        self.images = sorted(
            os.path.join(image_root, f)
            for f in os.listdir(image_root)
            if f.endswith(('.jpg', '.png'))
        )
        self.gts = sorted(
            os.path.join(gt_root, f)
            for f in os.listdir(gt_root)
            if f.endswith(('.tif', '.png'))
        )
        # Normalization is intentionally not applied; to enable it, append
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
        ])
        # Not used by load_data (which returns the raw PIL mask); kept for
        # callers that may apply it themselves.
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def __len__(self):
        """Return the number of images found (same as ``self.size``)."""
        return self.size

    def load_data(self):
        """Load the next (image, gt, name) triple and advance the cursor.

        Returns:
            image: tensor of shape (1, 3, testsize, testsize) produced by
                Resize + ToTensor on the RGB-converted image, with a leading
                batch dimension added.
            gt: PIL image in mode 'L' at its original resolution. No
                transform is applied to it here.
            name: basename of the image file, with a final ``.jpg`` extension
                replaced by ``.png``.

        Raises:
            IndexError: when called more than ``self.size`` times, or when
                there are fewer ground-truth files than images.
        """
        image_path = self.images[self.index]
        image = self.transform(self.rgb_loader(image_path)).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        name = os.path.basename(image_path)
        if name.endswith('.jpg'):
            # Replace only the final extension (the old split('.jpg') cut the
            # name at the first '.jpg' occurrence anywhere in it).
            name = os.path.splitext(name)[0] + '.png'
        self.index += 1
        return image, gt, name

    def rgb_loader(self, path):
        """Open the image at ``path`` and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to load while f is still open.
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open the image at ``path`` and return it converted to mode 'L'."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to load while f is still open.
            return img.convert('L')