| #Adopted from ACSNet | |||||
| import models | |||||
| import torch | |||||
| import os | |||||
| def generate_model(opt): | |||||
| model = getattr(models, opt.model)(opt.nclasses) | |||||
| if opt.use_gpu: | |||||
| model.cuda() | |||||
| torch.backends.cudnn.benchmark = True | |||||
| if opt.load_ckpt is not None: | |||||
| model_dict = model.state_dict() | |||||
| #load_ckpt_path = os.path.join('./checkpoints/exp'+str(opt.expID)+'/', opt.load_ckpt + '.pth') | |||||
| load_ckpt_path = os.path.join('./checkpoints/exp-colondb/', 'ck_'+ str(opt.load_ckpt) + '.pth') | |||||
| print(load_ckpt_path) | |||||
| assert os.path.isfile(load_ckpt_path), 'No checkpoint found.' | |||||
| print('Loading checkpoint......') | |||||
| checkpoint = torch.load(load_ckpt_path) | |||||
| new_dict = {k : v for k, v in checkpoint.items() if k in model_dict.keys()} | |||||
| model_dict.update(new_dict) | |||||
| model.load_state_dict(model_dict) | |||||
| print('Done') | |||||
| return model |
| # Adopted from PRAnet | |||||
| import os | |||||
| from PIL import Image | |||||
| import torch.utils.data as data | |||||
| import torchvision.transforms as transforms | |||||
| class test_dataset: | |||||
| def __init__(self, image_root, gt_root, testsize): | |||||
| self.testsize = testsize | |||||
| self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] | |||||
| self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] | |||||
| self.images = sorted(self.images) | |||||
| self.gts = sorted(self.gts) | |||||
| self.transform = transforms.Compose([ | |||||
| transforms.Resize((self.testsize, self.testsize)), | |||||
| transforms.ToTensor() | |||||
| #transforms.Normalize([0.485, 0.456, 0.406], | |||||
| #[0.229, 0.224, 0.225]) | |||||
| ]) | |||||
| self.gt_transform = transforms.ToTensor() | |||||
| self.size = len(self.images) | |||||
| self.index = 0 | |||||
| def load_data(self): | |||||
| image = self.rgb_loader(self.images[self.index]) | |||||
| image = self.transform(image).unsqueeze(0) | |||||
| gt = self.binary_loader(self.gts[self.index]) | |||||
| name = self.images[self.index].split('/')[-1] | |||||
| if name.endswith('.jpg'): | |||||
| name = name.split('.jpg')[0] + '.png' | |||||
| self.index += 1 | |||||
| return image, gt, name | |||||
| def rgb_loader(self, path): | |||||
| with open(path, 'rb') as f: | |||||
| img = Image.open(f) | |||||
| return img.convert('RGB') | |||||
| def binary_loader(self, path): | |||||
| with open(path, 'rb') as f: | |||||
| img = Image.open(f) | |||||
| return img.convert('L') |
| import torch | |||||
| """ | |||||
| The evaluation implementation refers to the following paper: | |||||
| "Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation" | |||||
| https://github.com/Yuqi-cuhk/Polyp-Seg | |||||
| """ | |||||
| def evaluate(pred, gt, th): | |||||
| if isinstance(pred, (list, tuple)): | |||||
| pred = pred[0] | |||||
| pred_binary = (pred >= th).float() | |||||
| pred_binary_inverse = (pred_binary == 0).float() | |||||
| gt_binary = (gt >= th).float() | |||||
| gt_binary_inverse = (gt_binary == 0).float() | |||||
| TP = pred_binary.mul(gt_binary).sum() | |||||
| FP = pred_binary.mul(gt_binary_inverse).sum() | |||||
| TN = pred_binary_inverse.mul(gt_binary_inverse).sum() | |||||
| FN = pred_binary_inverse.mul(gt_binary).sum() | |||||
| if TP.item() == 0: | |||||
| # print('TP=0 now!') | |||||
| # print('Epoch: {}'.format(epoch)) | |||||
| # print('i_batch: {}'.format(i_batch)) | |||||
| TP = torch.Tensor([1]).cuda() | |||||
| # recall | |||||
| Recall = TP / (TP + FN) | |||||
| # Specificity or true negative rate | |||||
| Specificity = TN / (TN + FP) | |||||
| # Precision or positive predictive value | |||||
| Precision = TP / (TP + FP) | |||||
| # F1 score = Dice | |||||
| F1 = 2 * Precision * Recall / (Precision + Recall) | |||||
| # F2 score | |||||
| F2 = 5 * Precision * Recall / (4 * Precision + Recall) | |||||
| # Overall accuracy | |||||
| ACC_overall = (TP + TN) / (TP + FP + FN + TN) | |||||
| # IoU for poly | |||||
| IoU_poly = TP / (TP + FP + FN) | |||||
| # IoU for background | |||||
| IoU_bg = TN / (TN + FP + FN) | |||||
| # mean IoU | |||||
| IoU_mean = (IoU_poly + IoU_bg) / 2.0 | |||||
| #Dice | |||||
| Dice = (2 * TP)/(2*TP + FN + FP) | |||||
| return Recall, Specificity, Precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, Dice | |||||
| class Metrics(object): | |||||
| def __init__(self, metrics_list): | |||||
| self.metrics = {} | |||||
| for metric in metrics_list: | |||||
| self.metrics[metric] = 0 | |||||
| def update(self, **kwargs): | |||||
| for k, v in kwargs.items(): | |||||
| assert (k in self.metrics.keys()), "The k {} is not in metrics".format(k) | |||||
| if isinstance(v, torch.Tensor): | |||||
| v = v.item() | |||||
| self.metrics[k] += v | |||||
| def mean(self, total): | |||||
| mean_metrics = {} | |||||
| for k, v in self.metrics.items(): | |||||
| mean_metrics[k] = v / total | |||||
| return mean_metrics |