import csv
import datetime
import os
import zipfile

import cv2
import nibabel as nib
import numpy as np
import pydicom
from pydicom.dataset import FileDataset

def slices2nifti(ims, fn_out, spacing):
    """Save 2D slices to a 3D NIfTI file, encoding the voxel spacing.

    Writes into the global output directory `dir_out` and returns the output path.
    """
    if len(ims) < 300:  # cv2.merge does not support too many channels
        V = cv2.merge(ims)
    else:
        # keep the input dtype (int16) instead of defaulting to float64
        V = np.empty((ims[0].shape[0], ims[0].shape[1], len(ims)), dtype=ims[0].dtype)
        for i in range(len(ims)):
            V[:, :, i] = ims[i]

    # The transformation matrix (affine) suitable for 3D Slicer and ITK-SNAP
    T = np.array([[0, -spacing[1], 0, 0],
                  [-spacing[0], 0, 0, 0],
                  [0, 0, -spacing[2], 0],
                  [0, 0, 0, 1]])
    img = nib.Nifti1Image(V, T)
    path_out = os.path.join(dir_out, fn_out)
    nib.save(img, path_out)
    return path_out
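

# A small sanity-check sketch (hypothetical helper, not called by the pipeline
# below): confirm that a volume written by slices2nifti has the expected shape
# and that the voxel sizes encoded in its header match the DeepLesion spacing
# that was passed in.
def check_saved_nifti(path_out):
    """Print the shape and voxel sizes of a saved NIfTI volume. Illustrative only."""
    vol = nib.load(path_out)
    print(vol.shape)               # (rows, cols, num_slices)
    print(vol.header.get_zooms())  # voxel sizes (mm) derived from the affine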



def load_slices(dir, slice_idxs):
    """Load slices from 16-bit PNG files and return images with their filenames."""
    slice_idxs = np.array(slice_idxs)
    if not np.all(slice_idxs[1:] - slice_idxs[:-1] == 1):
        print(f"⚠️ Slice indices are not consecutive")


    ims = []
    filenames = []

    for slice_idx in slice_idxs:
        fn = f'{slice_idx:03d}.png'
        path = os.path.join(dir_in, dir, fn)
        im = cv2.imread(path, -1)  # -1 to preserve 16-bit depth
        assert im is not None, f'Error reading {path}'

        # DeepLesion 16-bit PNGs store intensity as HU + 32768; subtract the
        # offset and cast back to signed 16-bit to recover Hounsfield units.
        im_corrected = (im.astype(np.int32) - 32768).astype(np.int16)
        ims.append(im_corrected)
        filenames.append(fn.split('.')[0])

    return ims, filenames
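

# Illustrative only: a minimal sketch of applying a display window to the HU
# values recovered by load_slices. The [-175, 275] soft-tissue window is an
# assumed example range; the pipeline below does not depend on this helper.
def window_hu(im_hu, hu_min=-175, hu_max=275):
    """Clip HU values to [hu_min, hu_max] and rescale to 8-bit for quick viewing. Hypothetical helper."""
    im = np.clip(im_hu.astype(np.float32), hu_min, hu_max)
    return ((im - hu_min) / (hu_max - hu_min) * 255).astype(np.uint8)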
 

def read_DL_info():
    """Read voxel spacings and image indices from the DeepLesion info file (global `info_fn`)."""
    spacings = []
    idxs = []
    with open(info_fn, 'r') as csvfile:  # use 'r' mode for reading text files
        reader = csv.reader(csvfile)
        for rownum, row in enumerate(reader):
            if rownum == 0:
                continue  # skip the header row
            idxs.append([int(d) for d in row[1:4]])  # patient, study, series indices
            spacings.append([float(d) for d in row[12].split(',')])  # "x, y, z" spacing in mm

    idxs = np.array(idxs)
    spacings = np.array(spacings)
    return idxs, spacings
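

# A minimal sketch (assumption: DeepLesion sub-folders are named
# "<patient>_<study>_<series>", e.g. "000001_01_01") of how the arrays returned
# by read_DL_info can be used to look up the spacing of one sub-folder; the main
# loop at the bottom of this script does the same lookup inline.
def spacing_for_folder(dir1, idxs, spacings):
    """Return the (x, y, z) spacing in mm for one DeepLesion sub-folder name. Hypothetical helper."""
    key = np.array([int(d) for d in dir1.split('_')])
    i1 = np.where(np.all(idxs == key, axis=1))[0]
    return spacings[i1[0]]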

def nii_to_dicom(nii_path, output_folder, filenames):
    """Convert a 3D NIfTI volume into a folder of single-frame DICOM files."""
    nii_img = nib.load(nii_path)
    data = nii_img.get_fdata()
    affine = nii_img.affine
    num_slices = data.shape[2]

    assert len(filenames) == num_slices, "Length of filename list must match number of slices."

    nii_base = os.path.splitext(os.path.basename(nii_path))[0]
    nii_base = nii_base.split('.')[0]  # strip the remaining '.nii' of '.nii.gz'
    dicom_subfolder = os.path.join(output_folder, nii_base)
    os.makedirs(dicom_subfolder, exist_ok=True)

    # Generate the study/series UIDs once so every slice belongs to the same series.
    study_uid = pydicom.uid.generate_uid()
    series_uid = pydicom.uid.generate_uid()

    for i in range(num_slices):
        filename = os.path.join(dicom_subfolder, filenames[i]) + '.dcm'  # use provided filename

        file_meta = pydicom.Dataset()
        file_meta.MediaStorageSOPClassUID = pydicom.uid.SecondaryCaptureImageStorage
        file_meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
        file_meta.ImplementationClassUID = "1.2.3.4.5.6.7.8.9.0"
        file_meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian

        ds = FileDataset(filename, {}, file_meta=file_meta, preamble=b"\0" * 128)

        dt = datetime.datetime.now()
        ds.PatientName = "Test^Patient"
        ds.PatientID = "123456"
        ds.Modality = "CT"  # DeepLesion volumes are CT scans
        ds.StudyInstanceUID = study_uid
        ds.SeriesInstanceUID = series_uid
        ds.SOPInstanceUID = file_meta.MediaStorageSOPInstanceUID
        ds.SOPClassUID = file_meta.MediaStorageSOPClassUID
        ds.StudyDate = dt.strftime('%Y%m%d')
        ds.StudyTime = dt.strftime('%H%M%S')

        ds.Rows, ds.Columns = data.shape[:2]
        ds.InstanceNumber = i + 1
        # Placeholder geometry: position, orientation and spacing are not derived
        # from the NIfTI affine here (1 mm spacing and axis-aligned orientation are assumed).
        ds.ImagePositionPatient = [float(affine[0, 3]), float(affine[1, 3]), float(affine[2, 3] + i)]
        ds.ImageOrientationPatient = [1, 0, 0, 0, 1, 0]
        ds.PixelSpacing = [1.0, 1.0]
        ds.SliceThickness = 1.0
        ds.SamplesPerPixel = 1
        ds.PhotometricInterpretation = "MONOCHROME2"
        ds.BitsAllocated = 16
        ds.BitsStored = 16
        ds.HighBit = 15
        ds.PixelRepresentation = 1  # signed integers (HU values can be negative)

        # Cast to signed 16-bit to match PixelRepresentation = 1.
        pixel_array = data[:, :, i].astype(np.int16)
        ds.PixelData = pixel_array.tobytes()

        ds.is_little_endian = True
        ds.is_implicit_VR = False
        ds.save_as(filename)

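
# A minimal sketch (not wired into nii_to_dicom above) of recovering the voxel
# spacing from the NIfTI affine written by slices2nifti; it could replace the
# hard-coded PixelSpacing / SliceThickness placeholders, though the mapping to
# DICOM's [row, column] PixelSpacing ordering should be verified before relying on it.
def spacing_from_affine(affine):
    """Return the in-plane voxel sizes and slice spacing (mm) from a 4x4 affine. Hypothetical helper."""
    # Voxel sizes are the column norms of the affine's 3x3 rotation/scaling block.
    vx, vy, vz = np.sqrt((np.asarray(affine)[:3, :3] ** 2).sum(axis=0))
    return [float(vx), float(vy)], float(vz)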

def extract_and_collect_main_folders(zip_root_dir):
    """Extract every .zip archive directly under zip_root_dir and return the resulting folder paths."""
    extracted_paths = []

    for zip_name in sorted(os.listdir(zip_root_dir)):
        zip_path = os.path.join(zip_root_dir, zip_name)

        # Only process real .zip files
        if zip_name.lower().endswith('.zip') and os.path.isfile(zip_path):
            try:
                extract_folder = os.path.join(zip_root_dir, zip_name.replace('.zip', ''))
                os.makedirs(extract_folder, exist_ok=True)

                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_folder)

                # Delete the archive once it has been extracted
                os.remove(zip_path)

                # Collect the extracted sub-folders, skipping hidden ones
                extracted_subfolders = [
                    os.path.join(extract_folder, name)
                    for name in os.listdir(extract_folder)
                    if os.path.isdir(os.path.join(extract_folder, name)) and not name.startswith('.')
                ]

                if not extracted_subfolders:
                    extracted_paths.append(extract_folder)
                else:
                    extracted_paths.extend(extracted_subfolders)

            except zipfile.BadZipFile:
                print(f"⚠️ Skipping bad zip file: {zip_name}")
                continue

    return extracted_paths


# Main
zip_root_dir = ''  # root directory containing the DeepLesion .zip archives (placeholder)
folders = extract_and_collect_main_folders(zip_root_dir)
def find_image_folders(directory):
    """Return every sub-directory under `directory` whose name starts with 'Images_png_'."""
    image_folders = []

    for root, dirs, files in os.walk(directory):
        for dir_name in dirs:
            if dir_name.startswith("Images_png_"):
                image_folders.append(os.path.join(root, dir_name))

    return image_folders

 
folder_path = ''  # directory to search for extracted 'Images_png_*' folders (placeholder)
result = find_image_folders(folder_path)

print(result)

dir_out = ''  # output directory for the NIfTI files (placeholder)
out_fmt = '%s.nii.gz'  # format of the NIfTI file name to output
info_fn = ''  # path to the DeepLesion information file (DL_info.csv)
idxs, spacings = read_DL_info()  

for folder in result:
    dir_in = os.path.join(folder, 'Images_png')
    if not os.path.exists(dir_out):
        os.mkdir(dir_out)
    img_dirs = os.listdir(dir_in)
    img_dirs.sort()

    for dir1 in img_dirs:
        # Find the image info according to the folder's name
        idxs1 = np.array([int(d) for d in dir1.split('_')])
        i1 = np.where(np.all(idxs == idxs1, axis=1))[0]
        spacings1 = spacings[i1[0]]

        fns = os.listdir(os.path.join(dir_in, dir1))
        slices = [int(d[:-4]) for d in fns if d.endswith('.png')]
        slices.sort()

        groups = [slices]  # treat all slices in the folder as a single sub-volume

        for group in groups:
            # Group contains the slice indices of a sub-volume
            ims, names = load_slices(dir1, group)
            fn_out = out_fmt % dir1
            path_out = slices2nifti(ims, fn_out, spacings1)
            nii_to_dicom(path_out, "", names)