| #Adopted from ACSNet | |||||
| import torch | |||||
| from torch.utils.data import Dataset, DataLoader | |||||
| import torchvision.transforms.functional as F | |||||
| from torchvision import transforms | |||||
| import os | |||||
| from PIL import Image | |||||
| import os.path as osp | |||||
| from utils.transform import * | |||||
| # EndoScene Dataset | |||||
| class EndoScene(Dataset): | |||||
| def __init__(self, root, data_dir, mode='train', transform=None): | |||||
| super(EndoScene, self).__init__() | |||||
| data_path1 = osp.join(root, data_dir) | |||||
| #data_path2 = osp.join(root, data_dir) + '/CVC-612' | |||||
| self.imglist = [] | |||||
| self.gtlist = [] | |||||
| datalist1 = os.listdir(osp.join(data_path1, 'image')) | |||||
| for data1 in datalist1: | |||||
| self.imglist.append(osp.join(data_path1 + '/image', data1)) | |||||
| self.gtlist.append(osp.join(data_path1 + '/gtpolyp', data1)) | |||||
| #datalist2 = os.listdir(osp.join(data_path2, 'image')) | |||||
| #for data2 in datalist2: | |||||
| #self.imglist.append(osp.join(data_path2 + '/image', data2)) | |||||
| #self.gtlist.append(osp.join(data_path2 + '/gtpolyp', data2)) | |||||
| if transform is None: | |||||
| if mode == 'train': | |||||
| transform = transforms.Compose([ | |||||
| Resize((320, 320)), | |||||
| RandomHorizontalFlip(), | |||||
| RandomVerticalFlip(), | |||||
| RandomRotation(90), | |||||
| RandomZoom((0.9, 1.1)), | |||||
| #Translation(10), | |||||
| RandomCrop((256, 256)), | |||||
| #transforms.Normalize([0.485, 0.456, 0.406], | |||||
| #[0.229, 0.224, 0.225])] | |||||
| ToTensor(), | |||||
| ]) | |||||
| elif mode == 'valid' or mode == 'test': | |||||
| transform = transforms.Compose([ | |||||
| Resize((320, 320)), | |||||
| ToTensor(), | |||||
| ]) | |||||
| self.transform = transform | |||||
| def __getitem__(self, index): | |||||
| img_path = self.imglist[index] | |||||
| gt_path = self.gtlist[index] | |||||
| img = Image.open(img_path).convert('RGB') | |||||
| gt = Image.open(gt_path).convert('L') | |||||
| data = {'image': img, 'label': gt} | |||||
| if self.transform: | |||||
| data = self.transform(data) | |||||
| return data | |||||
| def __len__(self): | |||||
| return len(self.imglist) |