|
123456789101112131415161718192021222324252627282930313233343536373839404142 |
- # Adapted from PraNet
- import os
- from PIL import Image
- import torch.utils.data as data
- import torchvision.transforms as transforms
-
class test_dataset:
    """Sequential loader for paired test images and ground-truth masks.

    Images (``.jpg``/``.png``) under ``image_root`` and masks
    (``.tif``/``.png``) under ``gt_root`` are each sorted by path and
    paired by position. Call :meth:`load_data` once per sample;
    ``self.size`` (or ``len(ds)``) gives the number of images.

    NOTE(review): pairing is purely positional, so both directories are
    assumed to contain matching, equally-sized file sets — confirm upstream.
    """

    # File-name suffixes accepted for input images and for masks.
    IMAGE_EXTS = ('.jpg', '.png')
    GT_EXTS = ('.tif', '.png')

    def __init__(self, image_root, gt_root, testsize):
        """
        Args:
            image_root: directory containing the input images.
            gt_root: directory containing the ground-truth masks.
            testsize: side length that images are resized to (square).
        """
        self.testsize = testsize
        self.images = self._list_files(image_root, self.IMAGE_EXTS)
        self.gts = self._list_files(gt_root, self.GT_EXTS)
        # Resize to a square, then convert to a tensor. No normalization
        # is applied.
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
        ])
        # Not used inside this class; presumably consumed by callers — verify.
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def __len__(self):
        """Return the number of images (same as ``self.size``)."""
        return self.size

    @staticmethod
    def _list_files(root, exts):
        """Return sorted paths of entries in *root* whose names end with any of *exts*."""
        # os.path.join works whether or not *root* has a trailing separator.
        return sorted(
            os.path.join(root, f) for f in os.listdir(root) if f.endswith(exts)
        )

    @staticmethod
    def _output_name(path):
        """Return the base name of *path*, with a trailing ``.jpg`` mapped to ``.png``."""
        name = os.path.basename(path)
        if name.endswith('.jpg'):
            # Strip only the trailing suffix; split('.jpg') would also cut
            # at an earlier '.jpg' occurring inside the name.
            name = name[:-len('.jpg')] + '.png'
        return name

    def load_data(self):
        """Return the next ``(image, gt, name)`` sample and advance the cursor.

        Returns:
            image: ``self.transform`` applied to the RGB image, with a
                leading dimension of size 1 added via ``unsqueeze(0)``.
            gt: the matching mask as a grayscale (mode ``'L'``) PIL image.
            name: the image's file name, with ``.jpg`` replaced by ``.png``.

        Raises:
            IndexError: if called after every sample has been read.
        """
        image_path = self.images[self.index]
        image = self.transform(self.rgb_loader(image_path)).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        name = self._output_name(image_path)
        self.index += 1
        return image, gt, name

    def rgb_loader(self, path):
        """Open the image at *path* and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() loads the pixel data, so it must run before the
            # file handle is closed.
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open the image at *path* and return it converted to grayscale ('L')."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() loads the pixel data while the file is still open.
            return img.convert('L')
|