| @@ -1,66 +0,0 @@ | |||
| #Adopted from ACSNet | |||
| import torch | |||
| from torch.utils.data import Dataset, DataLoader | |||
| import torchvision.transforms.functional as F | |||
| from torchvision import transforms | |||
| import os | |||
| from PIL import Image | |||
| import os.path as osp | |||
| from utils.transform import * | |||
| # EndoScene Dataset | |||
| class EndoScene(Dataset): | |||
| def __init__(self, root, data_dir, mode='train', transform=None): | |||
| super(EndoScene, self).__init__() | |||
| data_path1 = osp.join(root, data_dir) | |||
| #data_path2 = osp.join(root, data_dir) + '/CVC-612' | |||
| self.imglist = [] | |||
| self.gtlist = [] | |||
| datalist1 = os.listdir(osp.join(data_path1, 'image')) | |||
| for data1 in datalist1: | |||
| self.imglist.append(osp.join(data_path1 + '/image', data1)) | |||
| self.gtlist.append(osp.join(data_path1 + '/gtpolyp', data1)) | |||
| #datalist2 = os.listdir(osp.join(data_path2, 'image')) | |||
| #for data2 in datalist2: | |||
| #self.imglist.append(osp.join(data_path2 + '/image', data2)) | |||
| #self.gtlist.append(osp.join(data_path2 + '/gtpolyp', data2)) | |||
| if transform is None: | |||
| if mode == 'train': | |||
| transform = transforms.Compose([ | |||
| Resize((320, 320)), | |||
| RandomHorizontalFlip(), | |||
| RandomVerticalFlip(), | |||
| RandomRotation(90), | |||
| RandomZoom((0.9, 1.1)), | |||
| #Translation(10), | |||
| RandomCrop((256, 256)), | |||
| #transforms.Normalize([0.485, 0.456, 0.406], | |||
| #[0.229, 0.224, 0.225])] | |||
| ToTensor(), | |||
| ]) | |||
| elif mode == 'valid' or mode == 'test': | |||
| transform = transforms.Compose([ | |||
| Resize((320, 320)), | |||
| ToTensor(), | |||
| ]) | |||
| self.transform = transform | |||
| def __getitem__(self, index): | |||
| img_path = self.imglist[index] | |||
| gt_path = self.gtlist[index] | |||
| img = Image.open(img_path).convert('RGB') | |||
| gt = Image.open(gt_path).convert('L') | |||
| data = {'image': img, 'label': gt} | |||
| if self.transform: | |||
| data = self.transform(data) | |||
| return data | |||
| def __len__(self): | |||
| return len(self.imglist) | |||