@@ -5,24 +5,26 @@ import os | |||
parse = argparse.ArgumentParser(description='PyTorch Polyp Segmentation') | |||
"-------------------data option--------------------------" | |||
parse.add_argument('--root', type=str, default='/scratch/krushi1992/MICCAI2021/new-med-up/CVC-frame/') | |||
parse.add_argument('--dataset', type=str, default='EndoScene') | |||
parse.add_argument('--train_data_dir', type=str, default='train') | |||
parse.add_argument('--valid_data_dir', type=str, default='valid') | |||
parse.add_argument('--test_data_dir', type=str, default='test') | |||
# parse.add_argument('--root', type=str, default='/media/external_10TB/10TB/pourmand/CVC-EndoSceneStill') | |||
# parse.add_argument('--dataset', type=str, default='EndoScene') | |||
parse.add_argument('--root',default='/media/external_10TB/10TB/pourmand/Kvasir-SEG/Kvasir-SEG') | |||
parse.add_argument('--dataset',type=str,default='kvasir_SEG') | |||
parse.add_argument('--train_data_dir', type=str, default='') | |||
parse.add_argument('--valid_data_dir', type=str, default='') | |||
parse.add_argument('--test_data_dir', type=str, default='') | |||
"-------------------training option-----------------------" | |||
parse.add_argument('--mode', type=str, default='train') | |||
parse.add_argument('--nEpoch', type=int, default=200) | |||
parse.add_argument('--nEpoch', type=int, default=10) | |||
parse.add_argument('--batch_size', type=float, default=4) | |||
parse.add_argument('--num_workers', type=int, default=0) | |||
parse.add_argument('--use_gpu', type=bool, default=True) | |||
parse.add_argument('--load_ckpt', type=str, default=None) | |||
parse.add_argument('--model', type=str, default='EUNet') | |||
parse.add_argument('--expID', type=int, default=1) | |||
parse.add_argument('--ckpt_period', type=int, default=50) | |||
parse.add_argument('--ckpt_period', type=int, default=0) | |||
parse.add_argument('--weight_const', type=float, default=0.3) | |||
"-------------------optimizer option-----------------------" |
@@ -43,6 +43,7 @@ def valid(model, valid_dataloader, total_batch): | |||
def train(): | |||
file_name = 'results.txt' | |||
model = generate_model(opt) | |||
#model = nn.DataParallel(model) | |||
@@ -101,6 +102,14 @@ def train(): | |||
metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'], | |||
metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'])) | |||
with open(file_name,'a') as f: | |||
f.write('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f,' | |||
' F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' | |||
% (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'], | |||
metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'], | |||
metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'])+'\n') | |||
if ((epoch + 1) % opt.ckpt_period == 0): | |||
torch.save(model.state_dict(), './checkpoints/exp' + str(opt.expID)+"/ck_{}.pth".format(epoch + 1)) | |||
@@ -76,13 +76,13 @@ def DeepSupervisionLoss(pred, gt): | |||
criterion = BceDiceLoss() | |||
loss0 = criterion(d0, gt) | |||
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) | |||
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) | |||
loss1 = criterion(d1, gt) | |||
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) | |||
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) | |||
loss2 = criterion(d2, gt) | |||
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) | |||
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) | |||
loss3 = criterion(d3, gt) | |||
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) | |||
gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True) | |||
loss4 = criterion(d4, gt) | |||
return loss0 + loss1 + loss2 + loss3 + loss4 |