|
|
|
|
|
|
|
|
#Adopted from the ACSNet |
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from torch.optim.lr_scheduler import LambdaLR |
|
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
import datasets |
|
|
|
|
|
from utils.metrics import evaluate |
|
|
|
|
|
from opt import opt |
|
|
|
|
|
from utils.comm import generate_model |
|
|
|
|
|
from utils.loss import DeepSupervisionLoss, BceDiceLoss |
|
|
|
|
|
from utils.metrics import Metrics |
|
|
|
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def valid(model, valid_dataloader, total_batch): |
|
|
|
|
|
|
|
|
|
|
|
model.eval() |
|
|
|
|
|
|
|
|
|
|
|
# Metrics_logger initialization |
|
|
|
|
|
metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2', |
|
|
|
|
|
'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean']) |
|
|
|
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
bar = tqdm(enumerate(valid_dataloader), total=total_batch) |
|
|
|
|
|
for i, data in bar: |
|
|
|
|
|
img, gt = data['image'], data['label'] |
|
|
|
|
|
|
|
|
|
|
|
if opt.use_gpu: |
|
|
|
|
|
img = img.cuda() |
|
|
|
|
|
gt = gt.cuda() |
|
|
|
|
|
|
|
|
|
|
|
output = model(img) |
|
|
|
|
|
_recall, _specificity, _precision, _F1, _F2, \ |
|
|
|
|
|
_ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate(output, gt, 0.5) |
|
|
|
|
|
|
|
|
|
|
|
metrics.update(recall= _recall, specificity= _specificity, precision= _precision, |
|
|
|
|
|
F1= _F1, F2= _F2, ACC_overall= _ACC_overall, IoU_poly= _IoU_poly, |
|
|
|
|
|
IoU_bg= _IoU_bg, IoU_mean= _IoU_mean |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
metrics_result = metrics.mean(total_batch) |
|
|
|
|
|
|
|
|
|
|
|
return metrics_result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(): |
|
|
|
|
|
|
|
|
|
|
|
model = generate_model(opt) |
|
|
|
|
|
#model = nn.DataParallel(model) |
|
|
|
|
|
|
|
|
|
|
|
# load data |
|
|
|
|
|
train_data = getattr(datasets, opt.dataset)(opt.root, opt.train_data_dir, mode='train') |
|
|
|
|
|
train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers) |
|
|
|
|
|
valid_data = getattr(datasets, opt.dataset)(opt.root, opt.valid_data_dir, mode='valid') |
|
|
|
|
|
valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=opt.num_workers) |
|
|
|
|
|
val_total_batch = int(len(valid_data) / 1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# load optimizer and scheduler |
|
|
|
|
|
optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.mt, weight_decay=opt.weight_decay) |
|
|
|
|
|
|
|
|
|
|
|
lr_lambda = lambda epoch: 1.0 - pow((epoch / opt.nEpoch), opt.power) |
|
|
|
|
|
scheduler = LambdaLR(optimizer, lr_lambda) |
|
|
|
|
|
|
|
|
|
|
|
# train |
|
|
|
|
|
print('Start training') |
|
|
|
|
|
print('---------------------------------\n') |
|
|
|
|
|
|
|
|
|
|
|
for epoch in range(opt.nEpoch): |
|
|
|
|
|
print('------ Epoch', epoch + 1) |
|
|
|
|
|
model.train() |
|
|
|
|
|
total_batch = int(len(train_data) / opt.batch_size) |
|
|
|
|
|
bar = tqdm(enumerate(train_dataloader), total=total_batch) |
|
|
|
|
|
|
|
|
|
|
|
for i, data in bar: |
|
|
|
|
|
img = data['image'] |
|
|
|
|
|
gt = data['label'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if opt.use_gpu: |
|
|
|
|
|
img = img.cuda() |
|
|
|
|
|
gt = gt.cuda() |
|
|
|
|
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
|
|
output = model(img) |
|
|
|
|
|
|
|
|
|
|
|
#loss = BceDiceLoss()(output, gt) |
|
|
|
|
|
loss = DeepSupervisionLoss(output, gt) |
|
|
|
|
|
loss.backward() |
|
|
|
|
|
|
|
|
|
|
|
optimizer.step() |
|
|
|
|
|
bar.set_postfix_str('loss: %.5s' % loss.item()) |
|
|
|
|
|
|
|
|
|
|
|
scheduler.step() |
|
|
|
|
|
|
|
|
|
|
|
metrics_result = valid(model, valid_dataloader, val_total_batch) |
|
|
|
|
|
|
|
|
|
|
|
print("Valid Result:") |
|
|
|
|
|
print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f,' |
|
|
|
|
|
' F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' |
|
|
|
|
|
% (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'], |
|
|
|
|
|
metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'], |
|
|
|
|
|
metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'])) |
|
|
|
|
|
|
|
|
|
|
|
if ((epoch + 1) % opt.ckpt_period == 0): |
|
|
|
|
|
torch.save(model.state_dict(), './checkpoints/exp' + str(opt.expID)+"/ck_{}.pth".format(epoch + 1)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
|
|
|
|
|
|
if opt.mode == 'train': |
|
|
|
|
|
print('---PolpySeg Train---') |
|
|
|
|
|
train() |
|
|
|
|
|
|
|
|
|
|
|
print('Done') |
|
|
|