You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loader.py 1.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Adapted from PraNet
  2. import os
  3. from PIL import Image
  4. import torch.utils.data as data
  5. import torchvision.transforms as transforms
  6. class test_dataset:
  7. def __init__(self, image_root, gt_root, testsize):
  8. self.testsize = testsize
  9. self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
  10. self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')]
  11. self.images = sorted(self.images)
  12. self.gts = sorted(self.gts)
  13. self.transform = transforms.Compose([
  14. transforms.Resize((self.testsize, self.testsize)),
  15. transforms.ToTensor()
  16. #transforms.Normalize([0.485, 0.456, 0.406],
  17. #[0.229, 0.224, 0.225])
  18. ])
  19. self.gt_transform = transforms.ToTensor()
  20. self.size = len(self.images)
  21. self.index = 0
  22. def load_data(self):
  23. image = self.rgb_loader(self.images[self.index])
  24. image = self.transform(image).unsqueeze(0)
  25. gt = self.binary_loader(self.gts[self.index])
  26. name = self.images[self.index].split('/')[-1]
  27. if name.endswith('.jpg'):
  28. name = name.split('.jpg')[0] + '.png'
  29. self.index += 1
  30. return image, gt, name
  31. def rgb_loader(self, path):
  32. with open(path, 'rb') as f:
  33. img = Image.open(f)
  34. return img.convert('RGB')
  35. def binary_loader(self, path):
  36. with open(path, 'rb') as f:
  37. img = Image.open(f)
  38. return img.convert('L')