| @@ -0,0 +1,27 @@ | |||
| #Adopted from ACSNet | |||
| import models | |||
| import torch | |||
| import os | |||
| def generate_model(opt): | |||
| model = getattr(models, opt.model)(opt.nclasses) | |||
| if opt.use_gpu: | |||
| model.cuda() | |||
| torch.backends.cudnn.benchmark = True | |||
| if opt.load_ckpt is not None: | |||
| model_dict = model.state_dict() | |||
| #load_ckpt_path = os.path.join('./checkpoints/exp'+str(opt.expID)+'/', opt.load_ckpt + '.pth') | |||
| load_ckpt_path = os.path.join('./checkpoints/exp-colondb/', 'ck_'+ str(opt.load_ckpt) + '.pth') | |||
| print(load_ckpt_path) | |||
| assert os.path.isfile(load_ckpt_path), 'No checkpoint found.' | |||
| print('Loading checkpoint......') | |||
| checkpoint = torch.load(load_ckpt_path) | |||
| new_dict = {k : v for k, v in checkpoint.items() if k in model_dict.keys()} | |||
| model_dict.update(new_dict) | |||
| model.load_state_dict(model_dict) | |||
| print('Done') | |||
| return model | |||
| @@ -0,0 +1,42 @@ | |||
| # Adopted from PRAnet | |||
| import os | |||
| from PIL import Image | |||
| import torch.utils.data as data | |||
| import torchvision.transforms as transforms | |||
| class test_dataset: | |||
| def __init__(self, image_root, gt_root, testsize): | |||
| self.testsize = testsize | |||
| self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] | |||
| self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] | |||
| self.images = sorted(self.images) | |||
| self.gts = sorted(self.gts) | |||
| self.transform = transforms.Compose([ | |||
| transforms.Resize((self.testsize, self.testsize)), | |||
| transforms.ToTensor() | |||
| #transforms.Normalize([0.485, 0.456, 0.406], | |||
| #[0.229, 0.224, 0.225]) | |||
| ]) | |||
| self.gt_transform = transforms.ToTensor() | |||
| self.size = len(self.images) | |||
| self.index = 0 | |||
| def load_data(self): | |||
| image = self.rgb_loader(self.images[self.index]) | |||
| image = self.transform(image).unsqueeze(0) | |||
| gt = self.binary_loader(self.gts[self.index]) | |||
| name = self.images[self.index].split('/')[-1] | |||
| if name.endswith('.jpg'): | |||
| name = name.split('.jpg')[0] + '.png' | |||
| self.index += 1 | |||
| return image, gt, name | |||
| def rgb_loader(self, path): | |||
| with open(path, 'rb') as f: | |||
| img = Image.open(f) | |||
| return img.convert('RGB') | |||
| def binary_loader(self, path): | |||
| with open(path, 'rb') as f: | |||
| img = Image.open(f) | |||
| return img.convert('L') | |||
| @@ -0,0 +1,80 @@ | |||
| import torch | |||
| """ | |||
| The evaluation implementation refers to the following paper: | |||
| "Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation" | |||
| https://github.com/Yuqi-cuhk/Polyp-Seg | |||
| """ | |||
| def evaluate(pred, gt, th): | |||
| if isinstance(pred, (list, tuple)): | |||
| pred = pred[0] | |||
| pred_binary = (pred >= th).float() | |||
| pred_binary_inverse = (pred_binary == 0).float() | |||
| gt_binary = (gt >= th).float() | |||
| gt_binary_inverse = (gt_binary == 0).float() | |||
| TP = pred_binary.mul(gt_binary).sum() | |||
| FP = pred_binary.mul(gt_binary_inverse).sum() | |||
| TN = pred_binary_inverse.mul(gt_binary_inverse).sum() | |||
| FN = pred_binary_inverse.mul(gt_binary).sum() | |||
| if TP.item() == 0: | |||
| # print('TP=0 now!') | |||
| # print('Epoch: {}'.format(epoch)) | |||
| # print('i_batch: {}'.format(i_batch)) | |||
| TP = torch.Tensor([1]).cuda() | |||
| # recall | |||
| Recall = TP / (TP + FN) | |||
| # Specificity or true negative rate | |||
| Specificity = TN / (TN + FP) | |||
| # Precision or positive predictive value | |||
| Precision = TP / (TP + FP) | |||
| # F1 score = Dice | |||
| F1 = 2 * Precision * Recall / (Precision + Recall) | |||
| # F2 score | |||
| F2 = 5 * Precision * Recall / (4 * Precision + Recall) | |||
| # Overall accuracy | |||
| ACC_overall = (TP + TN) / (TP + FP + FN + TN) | |||
| # IoU for poly | |||
| IoU_poly = TP / (TP + FP + FN) | |||
| # IoU for background | |||
| IoU_bg = TN / (TN + FP + FN) | |||
| # mean IoU | |||
| IoU_mean = (IoU_poly + IoU_bg) / 2.0 | |||
| #Dice | |||
| Dice = (2 * TP)/(2*TP + FN + FP) | |||
| return Recall, Specificity, Precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, Dice | |||
| class Metrics(object): | |||
| def __init__(self, metrics_list): | |||
| self.metrics = {} | |||
| for metric in metrics_list: | |||
| self.metrics[metric] = 0 | |||
| def update(self, **kwargs): | |||
| for k, v in kwargs.items(): | |||
| assert (k in self.metrics.keys()), "The k {} is not in metrics".format(k) | |||
| if isinstance(v, torch.Tensor): | |||
| v = v.item() | |||
| self.metrics[k] += v | |||
| def mean(self, total): | |||
| mean_metrics = {} | |||
| for k, v in self.metrics.items(): | |||
| mean_metrics[k] = v / total | |||
| return mean_metrics | |||