import os
import glob
import torch
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    NormalizeIntensityd, ScaleIntensityd, EnsureTyped,
    Activationsd, AsDiscreted, Invertd, SaveImaged
)
from monai.data import Dataset, DataLoader
from monai.networks.nets import SegResNet
from monai.inferers import SlidingWindowInferer

input_dir = "" #NIFTI Image
output_dir = ""
model_path = "models/model.pt"
os.makedirs(output_dir, exist_ok=True)

image_files = sorted(glob.glob(os.path.join(input_dir, "*.nii.gz")))
data_dicts = [{"image": f} for f in image_files]

pre_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys=["image"], nonzero=True),
    ScaleIntensityd(keys=["image"], minv=-1.0, maxv=1.0),
    EnsureTyped(keys=["image"]),
])


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


model = SegResNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=105,  
    init_filters=32,
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    dropout_prob=0.2
).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

dataset = Dataset(data=data_dicts, transform=pre_transforms)
dataloader = DataLoader(dataset, batch_size=1, num_workers=2)

inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=1, overlap=0.25)

post_transforms = Compose([
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="pred", argmax=True),
    Invertd(
        keys="pred",
        transform=pre_transforms,
        orig_keys="image",
        meta_key_postfix="meta_dict",
        nearest_interp=True,
        to_tensor=True
    ),
    SaveImaged(
        keys="pred",
        meta_keys="pred_meta_dict",
        output_dir=output_dir,
        output_postfix="seg",
        separate_folder=False,
        resample=True  
    )
])

with torch.no_grad():
    for batch_data in dataloader:
        batch_data = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch_data.items()}
        outputs = inferer(inputs=batch_data["image"], network=model)
        batch_data["pred"] = outputs
        batch_data = post_transforms(batch_data)
        print("✅ Done:", batch_data["image_meta_dict"]["filename_or_obj"][0])