import models | |||||
import torch | |||||
import os | |||||
#Adopted from the ACSNet | |||||
def generate_model(opt): | |||||
model = getattr(models, opt.model)(opt.nclasses) | |||||
if opt.use_gpu: | |||||
model.cuda() | |||||
torch.backends.cudnn.benchmark = True | |||||
if opt.load_ckpt is not None: | |||||
model_dict = model.state_dict() | |||||
#load_ckpt_path = os.path.join('./checkpoints/exp'+str(opt.expID)+'/', opt.load_ckpt + '.pth') | |||||
load_ckpt_path = os.path.join('./checkpoints/exp-cvcframe/', 'ck_'+ str(opt.load_ckpt) + '.pth') | |||||
print(load_ckpt_path) | |||||
assert os.path.isfile(load_ckpt_path), 'No checkpoint found.' | |||||
print('Loading checkpoint......') | |||||
checkpoint = torch.load(load_ckpt_path) | |||||
new_dict = {k : v for k, v in checkpoint.items() if k in model_dict.keys()} | |||||
model_dict.update(new_dict) | |||||
model.load_state_dict(model_dict) | |||||
print('Done') | |||||
return model |
import os | |||||
from PIL import Image | |||||
import torch.utils.data as data | |||||
import torchvision.transforms as transforms | |||||
class test_dataset: | |||||
def __init__(self, image_root, gt_root, testsize): | |||||
self.testsize = testsize | |||||
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] | |||||
self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] | |||||
self.images = sorted(self.images) | |||||
self.gts = sorted(self.gts) | |||||
self.transform = transforms.Compose([ | |||||
transforms.Resize((self.testsize, self.testsize)), | |||||
transforms.ToTensor() | |||||
#transforms.Normalize([0.485, 0.456, 0.406], | |||||
#[0.229, 0.224, 0.225]) | |||||
]) | |||||
self.gt_transform = transforms.ToTensor() | |||||
self.size = len(self.images) | |||||
self.index = 0 | |||||
def load_data(self): | |||||
image = self.rgb_loader(self.images[self.index]) | |||||
image = self.transform(image).unsqueeze(0) | |||||
gt = self.binary_loader(self.gts[self.index]) | |||||
name = self.images[self.index].split('/')[-1] | |||||
if name.endswith('.jpg'): | |||||
name = name.split('.jpg')[0] + '.png' | |||||
self.index += 1 | |||||
return image, gt, name | |||||
def rgb_loader(self, path): | |||||
with open(path, 'rb') as f: | |||||
img = Image.open(f) | |||||
return img.convert('RGB') | |||||
def binary_loader(self, path): | |||||
with open(path, 'rb') as f: | |||||
img = Image.open(f) | |||||
return img.convert('L') |
import torch | |||||
""" | |||||
The evaluation implementation refers to the following paper: | |||||
"Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation" | |||||
https://github.com/Yuqi-cuhk/Polyp-Seg | |||||
""" | |||||
def evaluate(pred, gt, th): | |||||
if isinstance(pred, (list, tuple)): | |||||
pred = pred[0] | |||||
pred_binary = (pred >= th).float() | |||||
pred_binary_inverse = (pred_binary == 0).float() | |||||
gt_binary = (gt >= th).float() | |||||
gt_binary_inverse = (gt_binary == 0).float() | |||||
TP = pred_binary.mul(gt_binary).sum() | |||||
FP = pred_binary.mul(gt_binary_inverse).sum() | |||||
TN = pred_binary_inverse.mul(gt_binary_inverse).sum() | |||||
FN = pred_binary_inverse.mul(gt_binary).sum() | |||||
if TP.item() == 0: | |||||
# print('TP=0 now!') | |||||
# print('Epoch: {}'.format(epoch)) | |||||
# print('i_batch: {}'.format(i_batch)) | |||||
TP = torch.Tensor([1]).cuda() | |||||
# recall | |||||
Recall = TP / (TP + FN) | |||||
# Specificity or true negative rate | |||||
Specificity = TN / (TN + FP) | |||||
# Precision or positive predictive value | |||||
Precision = TP / (TP + FP) | |||||
# F1 score = Dice | |||||
F1 = 2 * Precision * Recall / (Precision + Recall) | |||||
# F2 score | |||||
F2 = 5 * Precision * Recall / (4 * Precision + Recall) | |||||
# Overall accuracy | |||||
ACC_overall = (TP + TN) / (TP + FP + FN + TN) | |||||
# IoU for poly | |||||
IoU_poly = TP / (TP + FP + FN) | |||||
# IoU for background | |||||
IoU_bg = TN / (TN + FP + FN) | |||||
# mean IoU | |||||
IoU_mean = (IoU_poly + IoU_bg) / 2.0 | |||||
#Dice | |||||
Dice = (2 * TP)/(2*TP + FN + FP) | |||||
return Recall, Specificity, Precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean | |||||
class Metrics(object): | |||||
def __init__(self, metrics_list): | |||||
self.metrics = {} | |||||
for metric in metrics_list: | |||||
self.metrics[metric] = 0 | |||||
def update(self, **kwargs): | |||||
for k, v in kwargs.items(): | |||||
assert (k in self.metrics.keys()), "The k {} is not in metrics".format(k) | |||||
if isinstance(v, torch.Tensor): | |||||
v = v.item() | |||||
self.metrics[k] += v | |||||
def mean(self, total): | |||||
mean_metrics = {} | |||||
for k, v in self.metrics.items(): | |||||
mean_metrics[k] = v / total | |||||
return mean_metrics |