You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

kvasir_SEG.py 3.1KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Adapted from ACSNet
  2. import os
  3. import os.path as osp
  4. from utils.transform import *
  5. from torch.utils.data import Dataset
  6. from torchvision import transforms
  7. # KavSir-SEG Dataset
  8. class kvasir_SEG(Dataset):
  9. def __init__(self, root, data2_dir, mode='train', transform=None):
  10. super(kvasir_SEG, self).__init__()
  11. data_path = osp.join(root, data2_dir)
  12. self.imglist = []
  13. self.gtlist = []
  14. datalist = os.listdir(osp.join(data_path, 'images'))
  15. for data in datalist:
  16. self.imglist.append(osp.join(data_path+'/images', data))
  17. self.gtlist.append(osp.join(data_path+'/masks', data))
  18. if transform is None:
  19. if mode == 'train':
  20. transform = transforms.Compose([
  21. Resize((320,320 )),
  22. RandomHorizontalFlip(),
  23. RandomVerticalFlip(),
  24. RandomRotation(90),
  25. RandomZoom((0.9, 1.1)),
  26. Translation(10),
  27. RandomCrop((256, 256)),
  28. ToTensor(),
  29. ])
  30. elif mode == 'valid' or mode == 'test':
  31. transform = transforms.Compose([
  32. Resize((320, 320)),
  33. ToTensor(),
  34. ])
  35. self.transform = transform
  36. def __getitem__(self, index):
  37. img_path = self.imglist[index]
  38. gt_path = self.gtlist[index]
  39. img = Image.open(img_path).convert('RGB')
  40. gt = Image.open(gt_path).convert('L')
  41. data = {'image': img, 'label': gt}
  42. if self.transform:
  43. data = self.transform(data)
  44. return data
  45. def __len__(self):
  46. return len(self.imglist)
  47. class test_dataset:
  48. def __init__(self, image_root, gt_root, testsize):
  49. self.testsize = testsize
  50. self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
  51. self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')]
  52. self.images = sorted(self.images)
  53. self.gts = sorted(self.gts)
  54. self.transform = transforms.Compose([
  55. transforms.Resize((self.testsize, self.testsize)),
  56. transforms.ToTensor(),
  57. #transforms.Normalize([0.485, 0.456, 0.406],
  58. #[0.229, 0.224, 0.225])])
  59. self.gt_transform = transforms.ToTensor()
  60. self.size = len(self.images)
  61. self.index = 0
  62. def load_data(self):
  63. image = self.rgb_loader(self.images[self.index])
  64. image = self.transform(image).unsqueeze(0)
  65. gt = self.binary_loader(self.gts[self.index])
  66. name = self.images[self.index].split('/')[-1]
  67. if name.endswith('.jpg'):
  68. name = name.split('.jpg')[0] + '.png'
  69. self.index += 1
  70. return image, gt, name
  71. def rgb_loader(self, path):
  72. with open(path, 'rb') as f:
  73. img = Image.open(f)
  74. return img.convert('RGB')
  75. def binary_loader(self, path):
  76. with open(path, 'rb') as f:
  77. img = Image.open(f)
  78. return img.convert('L')