@@ -1,27 +0,0 @@ | |||
#Adopted from ACSNet | |||
import models | |||
import torch | |||
import os | |||
def generate_model(opt): | |||
model = getattr(models, opt.model)(opt.nclasses) | |||
if opt.use_gpu: | |||
model.cuda() | |||
torch.backends.cudnn.benchmark = True | |||
if opt.load_ckpt is not None: | |||
model_dict = model.state_dict() | |||
#load_ckpt_path = os.path.join('./checkpoints/exp'+str(opt.expID)+'/', opt.load_ckpt + '.pth') | |||
load_ckpt_path = os.path.join('./checkpoints/exp-colondb/', 'ck_'+ str(opt.load_ckpt) + '.pth') | |||
print(load_ckpt_path) | |||
assert os.path.isfile(load_ckpt_path), 'No checkpoint found.' | |||
print('Loading checkpoint......') | |||
checkpoint = torch.load(load_ckpt_path) | |||
new_dict = {k : v for k, v in checkpoint.items() if k in model_dict.keys()} | |||
model_dict.update(new_dict) | |||
model.load_state_dict(model_dict) | |||
print('Done') | |||
return model |
@@ -1,42 +0,0 @@ | |||
# Adopted from PRAnet | |||
import os | |||
from PIL import Image | |||
import torch.utils.data as data | |||
import torchvision.transforms as transforms | |||
class test_dataset: | |||
def __init__(self, image_root, gt_root, testsize): | |||
self.testsize = testsize | |||
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] | |||
self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] | |||
self.images = sorted(self.images) | |||
self.gts = sorted(self.gts) | |||
self.transform = transforms.Compose([ | |||
transforms.Resize((self.testsize, self.testsize)), | |||
transforms.ToTensor() | |||
#transforms.Normalize([0.485, 0.456, 0.406], | |||
#[0.229, 0.224, 0.225]) | |||
]) | |||
self.gt_transform = transforms.ToTensor() | |||
self.size = len(self.images) | |||
self.index = 0 | |||
def load_data(self): | |||
image = self.rgb_loader(self.images[self.index]) | |||
image = self.transform(image).unsqueeze(0) | |||
gt = self.binary_loader(self.gts[self.index]) | |||
name = self.images[self.index].split('/')[-1] | |||
if name.endswith('.jpg'): | |||
name = name.split('.jpg')[0] + '.png' | |||
self.index += 1 | |||
return image, gt, name | |||
def rgb_loader(self, path): | |||
with open(path, 'rb') as f: | |||
img = Image.open(f) | |||
return img.convert('RGB') | |||
def binary_loader(self, path): | |||
with open(path, 'rb') as f: | |||
img = Image.open(f) | |||
return img.convert('L') |
@@ -1,80 +0,0 @@ | |||
import torch | |||
""" | |||
The evaluation implementation refers to the following paper: | |||
"Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation" | |||
https://github.com/Yuqi-cuhk/Polyp-Seg | |||
""" | |||
def evaluate(pred, gt, th): | |||
if isinstance(pred, (list, tuple)): | |||
pred = pred[0] | |||
pred_binary = (pred >= th).float() | |||
pred_binary_inverse = (pred_binary == 0).float() | |||
gt_binary = (gt >= th).float() | |||
gt_binary_inverse = (gt_binary == 0).float() | |||
TP = pred_binary.mul(gt_binary).sum() | |||
FP = pred_binary.mul(gt_binary_inverse).sum() | |||
TN = pred_binary_inverse.mul(gt_binary_inverse).sum() | |||
FN = pred_binary_inverse.mul(gt_binary).sum() | |||
if TP.item() == 0: | |||
# print('TP=0 now!') | |||
# print('Epoch: {}'.format(epoch)) | |||
# print('i_batch: {}'.format(i_batch)) | |||
TP = torch.Tensor([1]).cuda() | |||
# recall | |||
Recall = TP / (TP + FN) | |||
# Specificity or true negative rate | |||
Specificity = TN / (TN + FP) | |||
# Precision or positive predictive value | |||
Precision = TP / (TP + FP) | |||
# F1 score = Dice | |||
F1 = 2 * Precision * Recall / (Precision + Recall) | |||
# F2 score | |||
F2 = 5 * Precision * Recall / (4 * Precision + Recall) | |||
# Overall accuracy | |||
ACC_overall = (TP + TN) / (TP + FP + FN + TN) | |||
# IoU for poly | |||
IoU_poly = TP / (TP + FP + FN) | |||
# IoU for background | |||
IoU_bg = TN / (TN + FP + FN) | |||
# mean IoU | |||
IoU_mean = (IoU_poly + IoU_bg) / 2.0 | |||
#Dice | |||
Dice = (2 * TP)/(2*TP + FN + FP) | |||
return Recall, Specificity, Precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, Dice | |||
class Metrics(object): | |||
def __init__(self, metrics_list): | |||
self.metrics = {} | |||
for metric in metrics_list: | |||
self.metrics[metric] = 0 | |||
def update(self, **kwargs): | |||
for k, v in kwargs.items(): | |||
assert (k in self.metrics.keys()), "The k {} is not in metrics".format(k) | |||
if isinstance(v, torch.Tensor): | |||
v = v.item() | |||
self.metrics[k] += v | |||
def mean(self, total): | |||
mean_metrics = {} | |||
for k, v in self.metrics.items(): | |||
mean_metrics[k] = v / total | |||
return mean_metrics |