| 
														 | 
														 | 
														 | 
														 | 
														 | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														import os | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														import os.path as osp | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														from utils.transform import * | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														from torch.utils.data import Dataset | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														from torchvision import transforms | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														# KavSir-SEG Dataset | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														class kvasir_SEG(Dataset): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														    def __init__(self, root, data2_dir, mode='train', transform=None): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        super(kvasir_SEG, self).__init__() | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        data_path = osp.join(root, data2_dir) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.imglist = [] | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.gtlist = [] | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        datalist = os.listdir(osp.join(data_path, 'images')) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        for data in datalist: | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            self.imglist.append(osp.join(data_path+'/images', data)) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            self.gtlist.append(osp.join(data_path+'/masks', data)) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        if transform is None: | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            if mode == 'train': | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														               transform = transforms.Compose([ | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   Resize((320,320 )), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   RandomHorizontalFlip(), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   RandomVerticalFlip(), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   RandomRotation(90), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   RandomZoom((0.9, 1.1)), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   Translation(10), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   RandomCrop((256, 256)), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   ToTensor(), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														               ]) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            elif mode == 'valid' or mode == 'test': | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                transform = transforms.Compose([ | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   Resize((320, 320)), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                   ToTensor(), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														               ]) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.transform = transform | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														    def __getitem__(self, index): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        img_path = self.imglist[index] | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        gt_path = self.gtlist[index] | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        img = Image.open(img_path).convert('RGB') | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        gt = Image.open(gt_path).convert('L') | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        data = {'image': img, 'label': gt} | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        if self.transform: | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            data = self.transform(data) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        return data | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														    def __len__(self): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        return len(self.imglist) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														""" | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														class test_dataset: | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														    def __init__(self, image_root, gt_root, testsize): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.testsize = testsize | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.images = sorted(self.images) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.gts = sorted(self.gts) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.transform = transforms.Compose([ | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            transforms.Resize((self.testsize, self.testsize)), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            transforms.ToTensor(), | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            #transforms.Normalize([0.485, 0.456, 0.406], | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														                                 #[0.229, 0.224, 0.225])]) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.gt_transform = transforms.ToTensor() | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.size = len(self.images) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.index = 0 | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														    def load_data(self): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        image = self.rgb_loader(self.images[self.index]) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        image = self.transform(image).unsqueeze(0) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        gt = self.binary_loader(self.gts[self.index]) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        name = self.images[self.index].split('/')[-1] | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        if name.endswith('.jpg'): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            name = name.split('.jpg')[0] + '.png' | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        self.index += 1 | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        return image, gt, name | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														    def rgb_loader(self, path): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        with open(path, 'rb') as f: | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            img = Image.open(f) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            return img.convert('RGB') | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														    def binary_loader(self, path): | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														        with open(path, 'rb') as f: | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            img = Image.open(f) | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														            return img.convert('L') | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														
  | 
													
													
												
													
														 | 
														 | 
														 | 
														 | 
														 | 
														""" |