Closed — jianjun0407 closed this issue 1 month ago.
The following is my code; please take a look and give me some guidance as well.
import os
from pathlib import Path

import itkwidgets
import numpy as np
import SimpleITK as sitk
import torch
from IPython.display import display
from ipywidgets import Button, HBox, VBox, interact_manual, interactive_output, widgets
from itkwidgets import view

from ct_sam.builder import build_sam
from ct_sam.predictor import SamPredictor
from ct_sam.utils.frame import voxel_to_world
from ct_sam.utils.io_utils import load_module_from_file
from ct_sam.utils.resample import crop_roi_with_center, flip_itkimage_torai, resample_itkimage_torai
if __name__ == "__main__":
    # Example driver: load a CT volume (and optional GT mask), build the SAM model from a
    # checkpoint, crop a 64^3 patch around a voxel, run a single positive click, and dump
    # the predicted mask and the image patch as headerless .raw files.

    # Hard-coded example inputs; edit these to point at your own data / checkpoint.
    pid = "FLARE22_Tr_0001"
    data_dir = Path("F:/ct-sam3d/ct_sam/examples")
    checkpoint = "F:/ct-sam3d/checkpoint/ckpt_1000/params.pth"
    spacing = [1.5, 1.5, 1.5]  # target spacing passed to resample_itkimage_torai

    # ============= (1) Load data =============
    # (1.1) Image: read and resample (linear interpolation, pad value -1024).
    image = sitk.ReadImage(str(data_dir / f"{pid}_0000.nii.gz"))
    image = resample_itkimage_torai(image, spacing, "linear", -1024)

    # (1.2) Optional ground-truth mask. A read/resample failure is reported and the script
    # continues with mask_gt = None (deliberate best-effort).
    mask_gt = None
    try:
        mask_gt = sitk.ReadImage(str(data_dir / f"{pid}.nii.gz"))
        mask_gt = resample_itkimage_torai(mask_gt, spacing, interpolator="nearest", pad_value=0)
    except Exception as e:
        print(f"read mask failed: {e}")
    print("image loaded!")

    # ============= (2) Load model parameters =============
    # (2.1) config.py is expected to sit next to the checkpoint file.
    config_file = Path(checkpoint).with_name("config.py")
    if not config_file.is_file():
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise FileNotFoundError(f"file config.py not found: {config_file}")

    # (2.2) Load the network definition by importing config.py.
    cfg_module = load_module_from_file(str(config_file))
    cfg = cfg_module.cfg

    # (2.3) Build the model with the checkpoint path injected into the config.
    cfg.update({"checkpoint": checkpoint})
    sam = build_sam(cfg)  # original note mentions "ResTV2" — presumably the backbone; TODO confirm
    if torch.cuda.is_available():
        sam.cuda()
    predictor = SamPredictor(sam, cfg.dataset)
    print("predictor device: ", predictor.device)

    # ============= (3) Testing =============
    # (3.1) Crop a 64x64x64 patch centred on a voxel of the resampled image.
    center_v = [68, 134, 92]  # voxel index in the resampled image
    center_w = voxel_to_world(image, center_v)
    # Reshape the flat direction tuple to 3x3 and transpose, so each unpacked row is one
    # column of the direction matrix (identity for this example image, per the original note).
    x_axis, y_axis, z_axis = np.array(image.GetDirection()).reshape(3, 3).transpose()
    image_patch = crop_roi_with_center(
        image, center_w, image.GetSpacing(), x_axis, y_axis, z_axis,
        [64, 64, 64], "linear", -1024,
    )

    # (3.2) Single positive click inside the patch.
    predictor.set_image(image_patch)
    point_coords = np.array([[33, 39, 32]])  # shape (1, 3); patch coordinates — axis order TODO confirm
    point_labels = np.array([1])  # shape (1,); 1 presumably marks a positive click
    mask_pred, probability, result3 = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
    )

    # (3.3) Dump the predicted mask and the image patch as raw arrays.
    # NOTE: sitk.GetArrayFromImage returns arrays indexed (z, y, x); transpose with
    # np.transpose(arr, (2, 1, 0)) if (x, y, z) order is needed.
    result_mask = sitk.GetArrayFromImage(mask_pred)  # original note: uint8
    patch_img = sitk.GetArrayFromImage(image_patch)  # original note: float32
    result_mask.tofile("./resultMask.raw")
    patch_img.tofile("./patchImg.raw")
    # .raw files carry no header, so record what is needed to reload them with np.fromfile.
    print(f"resultMask.raw: shape={result_mask.shape}, dtype={result_mask.dtype}")
    print(f"patchImg.raw: shape={patch_img.shape}, dtype={patch_img.dtype}")
Good news: the third problem was solved by setting `multimask_output=False`!
It's great to hear that you have successfully tested the model. Regarding the first issue, we currently have no plans to release a larger model. As for the second issue, our interaction is driven by clicks, rather than the sliding-window approach used by fully automated segmentation methods. For large organs, providing more clicks typically resolves the issue. On the algorithm side, we have designed a cross-patch prompting module to address this, but to keep the interaction logic simple and clear we have not integrated it into the interactive tool; manual clicking remains the primary interaction method.
I'm happy to report that I have tested your code successfully, and it is really fast — close to real time! I now have two questions I would like to discuss with you: