Browse Source

Add files via upload

main
rucv 3 years ago
parent
commit
0b3e20d9db
No account linked to committer's email address
3 changed files with 149 additions and 0 deletions
  1. 27
    0
      utils/comm.py
  2. 42
    0
      utils/data_loader.py
  3. 80
    0
      utils/metrics.py

+ 27
- 0
utils/comm.py View File

#Adopted from ACSNet
import models
import torch
import os


def generate_model(opt):
model = getattr(models, opt.model)(opt.nclasses)
if opt.use_gpu:
model.cuda()
torch.backends.cudnn.benchmark = True

if opt.load_ckpt is not None:
model_dict = model.state_dict()
#load_ckpt_path = os.path.join('./checkpoints/exp'+str(opt.expID)+'/', opt.load_ckpt + '.pth')
load_ckpt_path = os.path.join('./checkpoints/exp-colondb/', 'ck_'+ str(opt.load_ckpt) + '.pth')
print(load_ckpt_path)
assert os.path.isfile(load_ckpt_path), 'No checkpoint found.'
print('Loading checkpoint......')
checkpoint = torch.load(load_ckpt_path)
new_dict = {k : v for k, v in checkpoint.items() if k in model_dict.keys()}
model_dict.update(new_dict)
model.load_state_dict(model_dict)

print('Done')

return model

+ 42
- 0
utils/data_loader.py View File

# Adopted from PRAnet
import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms

class test_dataset:
def __init__(self, image_root, gt_root, testsize):
self.testsize = testsize
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')]
self.images = sorted(self.images)
self.gts = sorted(self.gts)
self.transform = transforms.Compose([
transforms.Resize((self.testsize, self.testsize)),
transforms.ToTensor()
#transforms.Normalize([0.485, 0.456, 0.406],
#[0.229, 0.224, 0.225])
])
self.gt_transform = transforms.ToTensor()
self.size = len(self.images)
self.index = 0

def load_data(self):
image = self.rgb_loader(self.images[self.index])
image = self.transform(image).unsqueeze(0)
gt = self.binary_loader(self.gts[self.index])
name = self.images[self.index].split('/')[-1]
if name.endswith('.jpg'):
name = name.split('.jpg')[0] + '.png'
self.index += 1
return image, gt, name

def rgb_loader(self, path):
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')

def binary_loader(self, path):
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('L')

+ 80
- 0
utils/metrics.py View File

import torch

"""
The evaluation implementation refers to the following paper:
"Selective Feature Aggregation Network with Area-Boundary Constraints for Polyp Segmentation"
https://github.com/Yuqi-cuhk/Polyp-Seg
"""
def evaluate(pred, gt, th):
if isinstance(pred, (list, tuple)):
pred = pred[0]

pred_binary = (pred >= th).float()
pred_binary_inverse = (pred_binary == 0).float()

gt_binary = (gt >= th).float()
gt_binary_inverse = (gt_binary == 0).float()

TP = pred_binary.mul(gt_binary).sum()
FP = pred_binary.mul(gt_binary_inverse).sum()
TN = pred_binary_inverse.mul(gt_binary_inverse).sum()
FN = pred_binary_inverse.mul(gt_binary).sum()

if TP.item() == 0:
# print('TP=0 now!')
# print('Epoch: {}'.format(epoch))
# print('i_batch: {}'.format(i_batch))
TP = torch.Tensor([1]).cuda()

# recall
Recall = TP / (TP + FN)

# Specificity or true negative rate
Specificity = TN / (TN + FP)

# Precision or positive predictive value
Precision = TP / (TP + FP)

# F1 score = Dice
F1 = 2 * Precision * Recall / (Precision + Recall)

# F2 score
F2 = 5 * Precision * Recall / (4 * Precision + Recall)

# Overall accuracy
ACC_overall = (TP + TN) / (TP + FP + FN + TN)

# IoU for poly
IoU_poly = TP / (TP + FP + FN)

# IoU for background
IoU_bg = TN / (TN + FP + FN)

# mean IoU
IoU_mean = (IoU_poly + IoU_bg) / 2.0

#Dice
Dice = (2 * TP)/(2*TP + FN + FP)
return Recall, Specificity, Precision, F1, F2, ACC_overall, IoU_poly, IoU_bg, IoU_mean, Dice


class Metrics(object):
def __init__(self, metrics_list):
self.metrics = {}
for metric in metrics_list:
self.metrics[metric] = 0

def update(self, **kwargs):
for k, v in kwargs.items():
assert (k in self.metrics.keys()), "The k {} is not in metrics".format(k)
if isinstance(v, torch.Tensor):
v = v.item()

self.metrics[k] += v

def mean(self, total):
mean_metrics = {}
for k, v in self.metrics.items():
mean_metrics[k] = v / total
return mean_metrics

Loading…
Cancel
Save