Organ-aware 3D lesion segmentation dataset and pipeline for abdominal CT analysis (ACM Multimedia 2025 candidate)
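The inference script below needs PyTorch, MONAI, and a NIfTI reader backend (MONAI's default LoadImaged reader for .nii.gz files is nibabel). A minimal environment sketch, with package names only; version pins are untested assumptions:

pip install torch monai nibabel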

segment.py (2.4 KB)

import os
import glob

import torch
from monai.data import DataLoader, Dataset, decollate_batch
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activationsd, AsDiscreted, Compose, EnsureChannelFirstd, EnsureTyped,
    Invertd, LoadImaged, NormalizeIntensityd, Orientationd, SaveImaged,
    ScaleIntensityd, Spacingd,
)

input_dir = ""  # directory containing the NIfTI images (*.nii.gz)
output_dir = ""  # directory where segmentation masks are written
model_path = "models/model.pt"

os.makedirs(output_dir, exist_ok=True)

image_files = sorted(glob.glob(os.path.join(input_dir, "*.nii.gz")))
data_dicts = [{"image": f} for f in image_files]

# Preprocessing: resample to 1.5 mm isotropic spacing, reorient to RAS,
# then normalize and rescale intensities to [-1, 1].
pre_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys=["image"], nonzero=True),
    ScaleIntensityd(keys=["image"], minv=-1.0, maxv=1.0),
    EnsureTyped(keys=["image"]),
])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 3D SegResNet with 105 output channels (background plus organ/lesion classes).
model = SegResNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=105,
    init_filters=32,
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    dropout_prob=0.2,
).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

dataset = Dataset(data=data_dicts, transform=pre_transforms)
dataloader = DataLoader(dataset, batch_size=1, num_workers=2)

# Sliding-window inference over 96^3 patches with 25% overlap keeps GPU
# memory bounded on full abdominal volumes.
inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=1, overlap=0.25)

# Postprocessing: softmax + argmax to a discrete label map, invert the spatial
# preprocessing so the mask matches the original image grid, then save.
post_transforms = Compose([
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="pred", argmax=True),
    Invertd(
        keys="pred",
        transform=pre_transforms,
        orig_keys="image",
        meta_key_postfix="meta_dict",
        nearest_interp=True,
        to_tensor=True,
    ),
    SaveImaged(
        keys="pred",
        meta_keys="pred_meta_dict",
        output_dir=output_dir,
        output_postfix="seg",
        separate_folder=False,
        resample=True,
    ),
])

with torch.no_grad():
    for batch_data in dataloader:
        batch_data["image"] = batch_data["image"].to(device)
        batch_data["pred"] = inferer(inputs=batch_data["image"], network=model)
        # Invertd and SaveImaged expect per-sample dicts, so decollate the
        # batch before applying the post-transforms.
        for item in decollate_batch(batch_data):
            item = post_transforms(item)
            # MetaTensor metadata (MONAI >= 0.9) carries the source file path.
            print("✅ Done:", item["image"].meta["filename_or_obj"])