You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test.py 2.1KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import torch
  2. from tqdm import tqdm
  3. from opt import opt
  4. from utils.metrics import evaluate
  5. import datasets
  6. from torch.utils.data import DataLoader
  7. from utils.comm import generate_model
  8. from utils.metrics import Metrics
  9. def test():
  10. print('loading data......')
  11. test_data = getattr(datasets, opt.dataset)(opt.root, opt.test_data_dir, mode='test')
  12. test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=opt.num_workers)
  13. total_batch = int(len(test_data) / 1)
  14. model = generate_model(opt)
  15. model.eval()
  16. # metrics_logger initialization
  17. metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2',
  18. 'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean', 'Dice'])
  19. with torch.no_grad():
  20. bar = tqdm(enumerate(test_dataloader), total=total_batch)
  21. for i, data in bar:
  22. img, gt = data['image'], data['label']
  23. if opt.use_gpu:
  24. img = img.cuda()
  25. gt = gt.cuda()
  26. output = model(img)
  27. _recall, _specificity, _precision, _F1, _F2, \
  28. _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean, _Dice = evaluate(output, gt)
  29. metrics.update(recall= _recall, specificity= _specificity, precision= _precision,
  30. F1= _F1, F2= _F2, ACC_overall= _ACC_overall, IoU_poly= _IoU_poly,
  31. IoU_bg= _IoU_bg, IoU_mean= _IoU_mean, Dice = _Dice
  32. )
  33. metrics_result = metrics.mean(total_batch)
  34. print("Test Result:")
  35. print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, '
  36. 'ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f'
  37. % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'],
  38. metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'],
  39. metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'], metric_result['Dice']))
  40. if __name__ == '__main__':
  41. if opt.mode == 'test':
  42. print('--- PolypSeg Test---')
  43. test()
  44. print('Done')