| @@ -0,0 +1,116 @@ | |||
| #Adopted from the ACSNet | |||
| import torch | |||
| from torch.utils.data import DataLoader | |||
| from torch.optim.lr_scheduler import LambdaLR | |||
| from tqdm import tqdm | |||
| import datasets | |||
| from utils.metrics import evaluate | |||
| from opt import opt | |||
| from utils.comm import generate_model | |||
| from utils.loss import DeepSupervisionLoss, BceDiceLoss | |||
| from utils.metrics import Metrics | |||
| import torch.nn as nn | |||
| def valid(model, valid_dataloader, total_batch): | |||
| model.eval() | |||
| # Metrics_logger initialization | |||
| metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2', | |||
| 'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean']) | |||
| with torch.no_grad(): | |||
| bar = tqdm(enumerate(valid_dataloader), total=total_batch) | |||
| for i, data in bar: | |||
| img, gt = data['image'], data['label'] | |||
| if opt.use_gpu: | |||
| img = img.cuda() | |||
| gt = gt.cuda() | |||
| output = model(img) | |||
| _recall, _specificity, _precision, _F1, _F2, \ | |||
| _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate(output, gt, 0.5) | |||
| metrics.update(recall= _recall, specificity= _specificity, precision= _precision, | |||
| F1= _F1, F2= _F2, ACC_overall= _ACC_overall, IoU_poly= _IoU_poly, | |||
| IoU_bg= _IoU_bg, IoU_mean= _IoU_mean | |||
| ) | |||
| metrics_result = metrics.mean(total_batch) | |||
| return metrics_result | |||
| def train(): | |||
| model = generate_model(opt) | |||
| #model = nn.DataParallel(model) | |||
| # load data | |||
| train_data = getattr(datasets, opt.dataset)(opt.root, opt.train_data_dir, mode='train') | |||
| train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers) | |||
| valid_data = getattr(datasets, opt.dataset)(opt.root, opt.valid_data_dir, mode='valid') | |||
| valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=opt.num_workers) | |||
| val_total_batch = int(len(valid_data) / 1) | |||
| # load optimizer and scheduler | |||
| optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.mt, weight_decay=opt.weight_decay) | |||
| lr_lambda = lambda epoch: 1.0 - pow((epoch / opt.nEpoch), opt.power) | |||
| scheduler = LambdaLR(optimizer, lr_lambda) | |||
| # train | |||
| print('Start training') | |||
| print('---------------------------------\n') | |||
| for epoch in range(opt.nEpoch): | |||
| print('------ Epoch', epoch + 1) | |||
| model.train() | |||
| total_batch = int(len(train_data) / opt.batch_size) | |||
| bar = tqdm(enumerate(train_dataloader), total=total_batch) | |||
| for i, data in bar: | |||
| img = data['image'] | |||
| gt = data['label'] | |||
| if opt.use_gpu: | |||
| img = img.cuda() | |||
| gt = gt.cuda() | |||
| optimizer.zero_grad() | |||
| output = model(img) | |||
| #loss = BceDiceLoss()(output, gt) | |||
| loss = DeepSupervisionLoss(output, gt) | |||
| loss.backward() | |||
| optimizer.step() | |||
| bar.set_postfix_str('loss: %.5s' % loss.item()) | |||
| scheduler.step() | |||
| metrics_result = valid(model, valid_dataloader, val_total_batch) | |||
| print("Valid Result:") | |||
| print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f,' | |||
| ' F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' | |||
| % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'], | |||
| metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'], | |||
| metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'])) | |||
| if ((epoch + 1) % opt.ckpt_period == 0): | |||
| torch.save(model.state_dict(), './checkpoints/exp' + str(opt.expID)+"/ck_{}.pth".format(epoch + 1)) | |||
| if __name__ == '__main__': | |||
| if opt.mode == 'train': | |||
| print('---PolpySeg Train---') | |||
| train() | |||
| print('Done') | |||