You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loader.py 1.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import os
  2. from PIL import Image
  3. import torch.utils.data as data
  4. import torchvision.transforms as transforms
  5. class test_dataset:
  6. def __init__(self, image_root, gt_root, testsize):
  7. self.testsize = testsize
  8. self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
  9. self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')]
  10. self.images = sorted(self.images)
  11. self.gts = sorted(self.gts)
  12. self.transform = transforms.Compose([
  13. transforms.Resize((self.testsize, self.testsize)),
  14. transforms.ToTensor()
  15. #transforms.Normalize([0.485, 0.456, 0.406],
  16. #[0.229, 0.224, 0.225])
  17. ])
  18. self.gt_transform = transforms.ToTensor()
  19. self.size = len(self.images)
  20. self.index = 0
  21. def load_data(self):
  22. image = self.rgb_loader(self.images[self.index])
  23. image = self.transform(image).unsqueeze(0)
  24. gt = self.binary_loader(self.gts[self.index])
  25. name = self.images[self.index].split('/')[-1]
  26. if name.endswith('.jpg'):
  27. name = name.split('.jpg')[0] + '.png'
  28. self.index += 1
  29. return image, gt, name
  30. def rgb_loader(self, path):
  31. with open(path, 'rb') as f:
  32. img = Image.open(f)
  33. return img.convert('RGB')
  34. def binary_loader(self, path):
  35. with open(path, 'rb') as f:
  36. img = Image.open(f)
  37. return img.convert('L')