Organ-aware 3D lesion segmentation dataset and pipeline for abdominal CT analysis (ACM Multimedia 2025 candidate)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sam_mask_generator.py 2.1KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import os
  2. import pandas as pd
  3. import numpy as np
  4. import cv2
  5. import torch
  6. from tqdm import tqdm
  7. from segment_anything import sam_model_registry, SamPredictor
  8. CSV_PATH = ""
  9. CT_IMAGES_ROOT = "" #deeplesion images
  10. OUTPUT_MASK_DIR = "sam_lesion_2Dmasks_compound_box"
  11. os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
  12. device = "cuda:0" if torch.cuda.is_available() else "cpu"
  13. model_type = "vit_h"
  14. checkpoint_path = "sam_vit_h_4b8939.pth"
  15. sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
  16. sam.to(device=device)
  17. predictor = SamPredictor(sam)
  18. df = pd.read_csv(CSV_PATH)
  19. for idx, row in tqdm(df.iterrows(), total=len(df)):
  20. try:
  21. series = "_".join(row["File_name"].split("_")[:3])
  22. key_slice = int(row["key_slice"])
  23. bbox = list(map(float, row["Bounding_boxes"].strip("[]").split(",")))
  24. bbox = list(map(int, bbox))
  25. x1, y1, x2, y2 = bbox
  26. slice_filename = f"{key_slice:03}.png"
  27. image_path = os.path.join(CT_IMAGES_ROOT, series, slice_filename)
  28. if not os.path.exists(image_path):
  29. print(f"There is no Image {image_path}")
  30. continue
  31. image_bgr = cv2.imread(image_path)
  32. image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
  33. predictor.set_image(image_rgb)
  34. outer_box = np.array([x1, y1, x2, y2])
  35. box_w, box_h = x2 - x1, y2 - y1
  36. cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
  37. scale = 0.3
  38. cw, ch = int(box_w * scale), int(box_h * scale)
  39. center_box = np.array([cx - cw // 2, cy - ch // 2, cx + cw // 2, cy + ch // 2])
  40. mask_outer, _, _ = predictor.predict(
  41. box=outer_box[None, :],
  42. multimask_output=False
  43. )
  44. mask_center, _, _ = predictor.predict(
  45. box=center_box[None, :],
  46. multimask_output=False
  47. )
  48. final_mask = np.logical_or(mask_outer[0], mask_center[0]).astype(np.uint8) * 255
  49. out_filename = f"{series}_{key_slice:03}.png"
  50. out_path = os.path.join(OUTPUT_MASK_DIR, out_filename)
  51. cv2.imwrite(out_path, final_mask)
  52. except Exception as e:
  53. print(f"⚠️ error{idx}: {e}")