import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F
from torchvision import transforms
import os
import os.path as osp
from PIL import Image

from utils.transform import *


# EndoScene Dataset: paired polyp images and ground-truth masks.
class EndoScene(Dataset):
    def __init__(self, root, data_dir, mode='train', transform=None):
        super(EndoScene, self).__init__()
        data_path1 = osp.join(root, data_dir)
        # data_path2 = osp.join(root, data_dir) + '/CVC-612'

        self.imglist = []
        self.gtlist = []

        # Collect image/mask pairs; masks share file names with their images.
        datalist1 = os.listdir(osp.join(data_path1, 'image'))
        for data1 in datalist1:
            self.imglist.append(osp.join(data_path1, 'image', data1))
            self.gtlist.append(osp.join(data_path1, 'gtpolyp', data1))

        # datalist2 = os.listdir(osp.join(data_path2, 'image'))
        # for data2 in datalist2:
        #     self.imglist.append(osp.join(data_path2, 'image', data2))
        #     self.gtlist.append(osp.join(data_path2, 'gtpolyp', data2))

        if transform is None:
            if mode == 'train':
                # Training augmentation: resize, random flips/rotation/zoom,
                # then a random crop to the network input size.
                transform = transforms.Compose([
                    Resize((320, 320)),
                    RandomHorizontalFlip(),
                    RandomVerticalFlip(),
                    RandomRotation(90),
                    RandomZoom((0.9, 1.1)),
                    # Translation(10),
                    RandomCrop((256, 256)),
                    # transforms.Normalize([0.485, 0.456, 0.406],
                    #                      [0.229, 0.224, 0.225]),
                    ToTensor(),
                ])
            elif mode == 'valid' or mode == 'test':
                # Validation/test: deterministic resize only.
                transform = transforms.Compose([
                    Resize((320, 320)),
                    ToTensor(),
                ])

        self.transform = transform

    def __getitem__(self, index):
        img_path = self.imglist[index]
        gt_path = self.gtlist[index]

        img = Image.open(img_path).convert('RGB')
        gt = Image.open(gt_path).convert('L')

        # Image and mask are transformed jointly via a dict so that
        # geometric augmentations stay aligned.
        data = {'image': img, 'label': gt}

        if self.transform:
            data = self.transform(data)

        return data

    def __len__(self):
        return len(self.imglist)
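

# --- Usage sketch (illustrative, not part of the original module) ---
# Minimal example of wiring this dataset into a DataLoader. The paths
# 'data/EndoScene' and 'TrainDataset' are hypothetical; the expected layout is
# <root>/<data_dir>/image and <root>/<data_dir>/gtpolyp with matching file
# names, as assumed by __init__ above. It also assumes ToTensor from
# utils.transform converts both 'image' and 'label' to tensors, so the default
# collate function can batch the dict.
if __name__ == '__main__':
    dataset = EndoScene(root='data/EndoScene', data_dir='TrainDataset', mode='train')
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    for batch in loader:
        images, labels = batch['image'], batch['label']
        print(images.shape, labels.shape)  # e.g. (4, 3, 256, 256), (4, 1, 256, 256)
        break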