You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

EndoScene.py 2.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #Adopted from ACSNet
  2. import torch
  3. from torch.utils.data import Dataset, DataLoader
  4. import torchvision.transforms.functional as F
  5. from torchvision import transforms
  6. import os
  7. from PIL import Image
  8. import os.path as osp
  9. from utils.transform import *
  10. # EndoScene Dataset
  class EndoScene(Dataset):
      """EndoScene polyp-segmentation dataset of image / ground-truth mask pairs.

      Expects the directory layout::

          <root>/<data_dir>/image/<name>     input images
          <root>/<data_dir>/gtpolyp/<name>   masks (same file name as the image)

      Each sample is a dict ``{'image': ..., 'label': ...}`` that is passed
      through ``transform`` when one is set.

      Args:
          root: Dataset root directory.
          data_dir: Sub-directory of ``root`` that holds ``image`` and ``gtpolyp``.
          mode: ``'train'``, ``'valid'`` or ``'test'``. Only consulted to pick a
              default pipeline when ``transform`` is None.
          transform: Callable applied to the sample dict. When None, a default
              pipeline is built from ``mode``.

      Raises:
          ValueError: If ``transform`` is None and ``mode`` is not one of
              ``'train'``, ``'valid'`` or ``'test'``.
      """

      def __init__(self, root, data_dir, mode='train', transform=None):
          super().__init__()
          data_path = osp.join(root, data_dir)
          # Sort so the index -> file mapping is reproducible; os.listdir()
          # returns entries in arbitrary, filesystem-dependent order.
          names = sorted(os.listdir(osp.join(data_path, 'image')))
          self.imglist = [osp.join(data_path, 'image', name) for name in names]
          # Each mask is looked up under the same file name as its image.
          self.gtlist = [osp.join(data_path, 'gtpolyp', name) for name in names]

          if transform is None:
              if mode == 'train':
                  # The transforms (from utils.transform) receive the whole
                  # sample dict -- presumably so image and label get the same
                  # random geometry; confirm against utils.transform.
                  transform = transforms.Compose([
                      Resize((320, 320)),
                      RandomHorizontalFlip(),
                      RandomVerticalFlip(),
                      RandomRotation(90),
                      RandomZoom((0.9, 1.1)),
                      RandomCrop((256, 256)),
                      # NOTE: no mean/std normalization is applied.
                      ToTensor(),
                  ])
              elif mode in ('valid', 'test'):
                  transform = transforms.Compose([
                      Resize((320, 320)),
                      ToTensor(),
                  ])
              else:
                  # An unknown mode would otherwise leave transform=None and
                  # silently yield raw PIL images; fail fast instead.
                  raise ValueError(
                      f"mode must be 'train', 'valid' or 'test', got {mode!r}"
                  )
          self.transform = transform

      def __getitem__(self, index):
          """Return the sample dict ``{'image': ..., 'label': ...}`` at ``index``.

          The image is loaded as RGB and the mask as single-channel ('L'),
          then ``self.transform`` (if set) is applied to the dict.
          """
          # Context managers close the underlying files; convert() returns an
          # independent image, so it stays valid after the source is closed.
          with Image.open(self.imglist[index]) as raw_img:
              img = raw_img.convert('RGB')
          with Image.open(self.gtlist[index]) as raw_gt:
              gt = raw_gt.convert('L')
          data = {'image': img, 'label': gt}
          if self.transform is not None:
              data = self.transform(data)
          return data

      def __len__(self):
          """Return the number of image/mask pairs."""
          return len(self.imglist)