You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

EndoScene.py 2.2KB

Lines 1–64
  1. import torch
  2. from torch.utils.data import Dataset, DataLoader
  3. import torchvision.transforms.functional as F
  4. from torchvision import transforms
  5. import os
  6. from PIL import Image
  7. import os.path as osp
  8. from utils.transform import *
  9. # EndoScene Dataset
  10. class EndoScene(Dataset):
  11. def __init__(self, root, data_dir, mode='train', transform=None):
  12. super(EndoScene, self).__init__()
  13. data_path1 = osp.join(root, data_dir)
  14. #data_path2 = osp.join(root, data_dir) + '/CVC-612'
  15. self.imglist = []
  16. self.gtlist = []
  17. datalist1 = os.listdir(osp.join(data_path1, 'image'))
  18. for data1 in datalist1:
  19. self.imglist.append(osp.join(data_path1 + '/image', data1))
  20. self.gtlist.append(osp.join(data_path1 + '/gtpolyp', data1))
  21. #datalist2 = os.listdir(osp.join(data_path2, 'image'))
  22. #for data2 in datalist2:
  23. #self.imglist.append(osp.join(data_path2 + '/image', data2))
  24. #self.gtlist.append(osp.join(data_path2 + '/gtpolyp', data2))
  25. if transform is None:
  26. if mode == 'train':
  27. transform = transforms.Compose([
  28. Resize((320, 320)),
  29. RandomHorizontalFlip(),
  30. RandomVerticalFlip(),
  31. RandomRotation(90),
  32. RandomZoom((0.9, 1.1)),
  33. #Translation(10),
  34. RandomCrop((256, 256)),
  35. #transforms.Normalize([0.485, 0.456, 0.406],
  36. # [0.229, 0.224, 0.225]),
  37. ToTensor(),
  38. ])
  39. elif mode == 'valid' or mode == 'test':
  40. transform = transforms.Compose([
  41. Resize((320, 320)),
  42. ToTensor(),
  43. ])
  44. self.transform = transform
  45. def __getitem__(self, index):
  46. img_path = self.imglist[index]
  47. gt_path = self.gtlist[index]
  48. img = Image.open(img_path).convert('RGB')
  49. gt = Image.open(gt_path).convert('L')
  50. data = {'image': img, 'label': gt}
  51. if self.transform:
  52. data = self.transform(data)
  53. return data
  54. def __len__(self):
  55. return len(self.imglist)