You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 4.3KB

2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import torch
  2. from torch.utils.data import DataLoader
  3. from torch.optim.lr_scheduler import LambdaLR
  4. from tqdm import tqdm
  5. import datasets
  6. from utils.metrics import evaluate
  7. from opt import opt
  8. from utils.comm import generate_model
  9. from utils.loss import DeepSupervisionLoss, BceDiceLoss
  10. from utils.metrics import Metrics
  11. import torch.nn as nn
  12. def valid(model, valid_dataloader, total_batch):
  13. model.eval()
  14. # Metrics_logger initialization
  15. metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2',
  16. 'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean'])
  17. with torch.no_grad():
  18. bar = tqdm(enumerate(valid_dataloader), total=total_batch)
  19. for i, data in bar:
  20. img, gt = data['image'], data['label']
  21. if opt.use_gpu:
  22. img = img.cuda()
  23. gt = gt.cuda()
  24. output = model(img)
  25. _recall, _specificity, _precision, _F1, _F2, \
  26. _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate(output, gt, 0.5)
  27. metrics.update(recall= _recall, specificity= _specificity, precision= _precision,
  28. F1= _F1, F2= _F2, ACC_overall= _ACC_overall, IoU_poly= _IoU_poly,
  29. IoU_bg= _IoU_bg, IoU_mean= _IoU_mean
  30. )
  31. metrics_result = metrics.mean(total_batch)
  32. return metrics_result
  33. def train():
  34. file_name = 'results.txt'
  35. model = generate_model(opt)
  36. #model = nn.DataParallel(model)
  37. # load data
  38. train_data = getattr(datasets, opt.dataset)(opt.root, opt.train_data_dir, mode='train')
  39. train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
  40. valid_data = getattr(datasets, opt.dataset)(opt.root, opt.valid_data_dir, mode='test')
  41. valid_dataloader = DataLoader(valid_data, batch_size=1, shuffle=False, num_workers=opt.num_workers)
  42. val_total_batch = int(len(valid_data) / 1)
  43. # load optimizer and scheduler
  44. optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.mt, weight_decay=opt.weight_decay)
  45. lr_lambda = lambda epoch: 1.0 - pow((epoch / opt.nEpoch), opt.power)
  46. scheduler = LambdaLR(optimizer, lr_lambda)
  47. # train
  48. print('Start training')
  49. print('---------------------------------\n')
  50. for epoch in range(opt.nEpoch):
  51. print('------ Epoch', epoch + 1)
  52. model.train()
  53. total_batch = int(len(train_data) / opt.batch_size)
  54. bar = tqdm(enumerate(train_dataloader), total=total_batch)
  55. for i, data in bar:
  56. img = data['image']
  57. gt = data['label']
  58. if opt.use_gpu:
  59. img = img.cuda()
  60. gt = gt.cuda()
  61. optimizer.zero_grad()
  62. output = model(img)
  63. #loss = BceDiceLoss()(output, gt)
  64. loss = DeepSupervisionLoss(output, gt)
  65. loss.backward()
  66. optimizer.step()
  67. bar.set_postfix_str('loss: %.5s' % loss.item())
  68. scheduler.step()
  69. metrics_result = valid(model, valid_dataloader, val_total_batch)
  70. print("Valid Result:")
  71. print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f,'
  72. ' F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f'
  73. % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'],
  74. metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'],
  75. metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean']))
  76. with open(file_name,'a') as f:
  77. f.write('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f,'
  78. ' F2: %.4f, ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f'
  79. % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'],
  80. metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'],
  81. metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'])+'\n')
  82. if ((epoch + 1) % opt.ckpt_period == 0):
  83. torch.save(model.state_dict(), './checkpoints/exp' + str(opt.expID)+"/ck_{}.pth".format(epoch + 1))
  84. if __name__ == '__main__':
  85. if opt.mode == 'train':
  86. print('---PolpySeg Train---')
  87. train()
  88. print('Done')