1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- import os
- import os.path as osp
- from utils.transform import *
- from torch.utils.data import Dataset
- from torchvision import transforms
-
-
- # Kvasir-SEG Dataset
class kvasir_SEG(Dataset):
    """Kvasir-SEG segmentation dataset yielding image/mask pairs.

    The dataset directory ``osp.join(root, data2_dir)`` must contain an
    ``images`` sub-directory and a ``masks`` sub-directory. Every entry found
    in ``images`` is paired with the entry of the *same file name* in
    ``masks``; the existence of the mask file is not checked at construction
    time, so a missing mask only surfaces when that sample is loaded.

    Args:
        root: Base directory that holds the dataset folder.
        data2_dir: Name of the dataset folder relative to ``root``.
        mode: ``'train'`` selects the augmentation pipeline; ``'valid'`` or
            ``'test'`` select a resize-only pipeline. Only consulted when
            ``transform`` is ``None``. Any other value leaves the dataset
            without a transform, in which case samples are returned as
            un-converted images.
        transform: Optional callable applied to the sample dict
            ``{'image': ..., 'label': ...}``. When given, ``mode`` is ignored.
    """

    def __init__(self, root, data2_dir, mode='train', transform=None):
        super().__init__()
        data_path = osp.join(root, data2_dir)
        image_dir = osp.join(data_path, 'images')
        mask_dir = osp.join(data_path, 'masks')

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        filenames = sorted(os.listdir(image_dir))
        self.imglist = [osp.join(image_dir, name) for name in filenames]
        self.gtlist = [osp.join(mask_dir, name) for name in filenames]

        if transform is None:
            # NOTE(review): Resize, RandomCrop, ToTensor, etc. are not
            # qualified with `transforms.`, so they presumably come from the
            # `utils.transform` star import and operate on the sample dict
            # rather than on a bare image -- confirm.
            if mode == 'train':
                transform = transforms.Compose([
                    Resize((320, 320)),
                    RandomHorizontalFlip(),
                    RandomVerticalFlip(),
                    RandomRotation(90),
                    RandomZoom((0.9, 1.1)),
                    Translation(10),
                    RandomCrop((256, 256)),
                    ToTensor(),
                ])
            elif mode in ('valid', 'test'):
                transform = transforms.Compose([
                    Resize((320, 320)),
                    ToTensor(),
                ])
        self.transform = transform

    def __getitem__(self, index):
        """Load sample ``index`` and return it as a dict.

        Returns:
            ``{'image': <RGB image>, 'label': <single-channel mask>}``, passed
            through ``self.transform`` when one is set.
        """
        # `convert` produces a new, fully loaded image, so the underlying
        # file can be closed right away instead of waiting for garbage
        # collection. NOTE(review): assumes `Image` is PIL's Image module
        # (it is not imported explicitly in this file) -- confirm.
        with Image.open(self.imglist[index]) as raw_img:
            img = raw_img.convert('RGB')
        with Image.open(self.gtlist[index]) as raw_gt:
            gt = raw_gt.convert('L')

        data = {'image': img, 'label': gt}
        if self.transform:
            data = self.transform(data)

        return data

    def __len__(self):
        """Return the number of image files discovered at construction."""
        return len(self.imglist)
- """
-
- class test_dataset:
- def __init__(self, image_root, gt_root, testsize):
- self.testsize = testsize
- self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
- self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')]
- self.images = sorted(self.images)
- self.gts = sorted(self.gts)
- self.transform = transforms.Compose([
- transforms.Resize((self.testsize, self.testsize)),
- transforms.ToTensor(),
- #transforms.Normalize([0.485, 0.456, 0.406],
- #[0.229, 0.224, 0.225])])
- self.gt_transform = transforms.ToTensor()
- self.size = len(self.images)
- self.index = 0
-
- def load_data(self):
- image = self.rgb_loader(self.images[self.index])
- image = self.transform(image).unsqueeze(0)
- gt = self.binary_loader(self.gts[self.index])
- name = self.images[self.index].split('/')[-1]
- if name.endswith('.jpg'):
- name = name.split('.jpg')[0] + '.png'
- self.index += 1
- return image, gt, name
-
- def rgb_loader(self, path):
- with open(path, 'rb') as f:
- img = Image.open(f)
- return img.convert('RGB')
-
- def binary_loader(self, path):
- with open(path, 'rb') as f:
- img = Image.open(f)
- return img.convert('L')
-
- """
|