12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
import glob
import os

import torch
from monai.data import DataLoader, Dataset, decollate_batch
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    Invertd,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    SaveImaged,
    ScaleIntensityd,
    Spacingd,
)

# Run configuration. The two directories ship empty and are meant to be filled
# in before running the script.
input_dir = ""  # folder searched for NIfTI volumes (*.nii.gz)
output_dir = ""  # folder passed to the image writer
model_path = "models/model.pt"

# os.makedirs("") raises FileNotFoundError, so only create the output folder
# once a path has actually been configured.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# One {"image": <path>} record per volume found, in sorted path order.
image_files = sorted(glob.glob(os.path.join(input_dir, "*.nii.gz")))
data_dicts = [{"image": path} for path in image_files]

# Preprocessing applied to the "image" entry of every record, in this order:
# load from disk, move channels first, resample to 1.5 spacing on each axis
# (bilinear), reorient to RAS, normalise the non-zero intensities, rescale
# to the range [-1, 1], and finally ensure the expected tensor type.
preprocessing_steps = [
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys=["image"], nonzero=True),
    ScaleIntensityd(keys=["image"], minv=-1.0, maxv=1.0),
    EnsureTyped(keys=["image"]),
]
pre_transforms = Compose(preprocessing_steps)


# Use the first CUDA device when one is available, otherwise run on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# 3D SegResNet taking one input channel and producing 105 output channels.
segresnet_kwargs = {
    "spatial_dims": 3,
    "in_channels": 1,
    "out_channels": 105,
    "init_filters": 32,
    "blocks_down": [1, 2, 2, 4],
    "blocks_up": [1, 1, 1],
    "dropout_prob": 0.2,
}
model = SegResNet(**segresnet_kwargs)
model.to(device)  # nn.Module.to moves the parameters in place

# NOTE(review): torch.load unpickles the checkpoint file; only load weights
# that come from a trusted source.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout

# Each record is preprocessed lazily by the dataset and served one volume per
# batch by two loader workers.
# NOTE(review): with num_workers > 0 on spawn-based platforms (Windows, macOS)
# the script body may need an `if __name__ == "__main__":` guard - confirm the
# target platform.
dataset = Dataset(transform=pre_transforms, data=data_dicts)
dataloader = DataLoader(dataset, num_workers=2, batch_size=1)

# Sliding-window inference: 96^3 windows, one window per forward pass,
# 25% overlap between neighbouring windows.
window_size = (96, 96, 96)
inferer = SlidingWindowInferer(overlap=0.25, sw_batch_size=1, roi_size=window_size)

# Post-processing applied to the "pred" entry, in this order.

# Softmax over the network output.
to_probabilities = Activationsd(keys="pred", softmax=True)

# Argmax to obtain a discrete label map.
to_labels = AsDiscreted(keys="pred", argmax=True)

# Invert pre_transforms on the prediction, using the "image" entry as the
# reference; nearest-neighbour interpolation is requested for the inversion.
restore_geometry = Invertd(
    keys="pred",
    transform=pre_transforms,
    orig_keys="image",
    meta_key_postfix="meta_dict",
    nearest_interp=True,
    to_tensor=True,
)

# Write the result into output_dir with a "seg" postfix, without creating a
# separate folder per input.
write_segmentation = SaveImaged(
    keys="pred",
    meta_keys="pred_meta_dict",
    output_dir=output_dir,
    output_postfix="seg",
    separate_folder=False,
    resample=True,
)

post_transforms = Compose(
    [to_probabilities, to_labels, restore_geometry, write_segmentation]
)

# Inference loop: run sliding-window prediction for every batch, then
# post-process and save each sample of the batch individually.
with torch.no_grad():
    for batch_data in dataloader:
        # Move tensor entries to the inference device; other entries
        # (e.g. metadata) are passed through unchanged.
        batch_data = {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in batch_data.items()
        }
        batch_data["pred"] = inferer(inputs=batch_data["image"], network=model)

        # The post transforms work on single channel-first samples. Applied to
        # the batched dict, softmax/argmax would act on the batch axis instead
        # of the class channels, so split the batch before post-processing.
        for sample in decollate_batch(batch_data):
            sample = post_transforms(sample)

            # The source filename may live on the tensor's own metadata or in
            # a separate "image_meta_dict" entry, depending on how the image
            # was loaded; accept either.
            image_meta = getattr(sample["image"], "meta", None) or sample.get(
                "image_meta_dict", {}
            )
            print("✅ Done:", image_meta.get("filename_or_obj"))
|