alibaba-damo-academy / ct-sam3d

Apache License 2.0

Parameters of the model #3

Closed jianjun0407 closed 1 month ago

jianjun0407 commented 1 month ago

I am happy to report that your code tested successfully on my side, and it is really fast, almost real-time! I now have two questions I would like to discuss with you:

  1. Would you consider releasing the parameters of other trained models, such as rest_v2_large and rest_v2_base, at a later point?
  2. In the test code, a patch is extracted around the point clicked by the user and then tested. Could the current code also be run in a sliding-window fashion (see the sketch right after this list)? Some regions of interest, such as the lung parenchyma, are larger than 64x64x64.
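
For illustration, here is a minimal sliding-window sketch. It reuses `voxel_to_world`, `crop_roi_with_center`, and `predictor.predict` exactly as they appear in the code below; the stride, the center-click prompt, and the paste-back logic are my own assumptions, not part of the repo:

```python
import numpy as np
import SimpleITK as sitk

from ct_sam.utils.frame import voxel_to_world
from ct_sam.utils.resample import crop_roi_with_center

def sliding_window_predict(image, predictor, patch=64, stride=48):
    """Tile the volume into overlapping patches and fuse per-patch masks."""
    sx, sy, sz = image.GetSize()                    # ITK order: (x, y, z)
    accum = np.zeros((sz, sy, sx), dtype=np.uint8)  # numpy order: (z, y, x)
    x_ax, y_ax, z_ax = np.array(image.GetDirection()).reshape(3, 3).T
    half = patch // 2

    def centers(n):
        return range(half, max(n - half, half) + 1, stride)

    for cz in centers(sz):
        for cy in centers(sy):
            for cx in centers(sx):
                center_w = voxel_to_world(image, [cx, cy, cz])
                roi = crop_roi_with_center(image, center_w, image.GetSpacing(),
                                           x_ax, y_ax, z_ax, [patch] * 3,
                                           "linear", -1024)
                predictor.set_image(roi)
                # prompt with the patch center; a real pipeline would use
                # clicks or seeds that are known to lie inside the organ
                mask, _, _ = predictor.predict(
                    point_coords=np.array([[half, half, half]]),
                    point_labels=np.array([1]),
                    multimask_output=False)
                pm = sitk.GetArrayFromImage(mask).astype(np.uint8)
                # paste back, clipping patches that overhang the volume
                z0, y0, x0 = cz - half, cy - half, cx - half
                z1 = min(z0 + patch, sz)
                y1 = min(y0 + patch, sy)
                x1 = min(x0 + patch, sx)
                accum[z0:z1, y0:y1, x0:x1] |= pm[:z1 - z0, :y1 - y0, :x1 - x0]
    return accum
```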
jianjun0407 commented 1 month ago
  3. I tried to run the algorithm under PyCharm instead of JupyterLab with itkwidgets. The predicted mask is somewhat abnormal, as shown in the figure below; I thought maybe the coordinates of the prompt point were wrong.

(attached figure: error_result)

My code is below; please advise:

```python
# -*- coding:utf-8 -*-
import os

import numpy as np
import SimpleITK as sitk
import torch

from ct_sam.builder import build_sam
from ct_sam.predictor import SamPredictor
from ct_sam.utils.frame import voxel_to_world
from ct_sam.utils.io_utils import load_module_from_file
from ct_sam.utils.resample import crop_roi_with_center, resample_itkimage_torai

if __name__ == "__main__":
    # ============= (1) load the test image =============
    # (1.1) image
    pid = "FLARE22_Tr_0001"
    image = sitk.ReadImage(f"F:/ct-sam3d/ct_sam/examples/{pid}_0000.nii.gz")
    image = resample_itkimage_torai(image, [1.5, 1.5, 1.5], "linear", -1024)

    # (1.2) ground-truth mask (optional)
    mask_gt = None
    try:
        mask_gt = sitk.ReadImage(f"F:/ct-sam3d/ct_sam/examples/{pid}.nii.gz")
        mask_gt = resample_itkimage_torai(mask_gt, [1.5, 1.5, 1.5], interpolator="nearest", pad_value=0)
    except Exception as e:
        print(f"read mask failed: {e}")

    print("image loaded!")

    # ============= (2) load model parameters =============
    # (2.1) paths
    checkpoint = "F:/ct-sam3d/checkpoint/ckpt_1000/params.pth"
    config_file = os.path.join(os.path.dirname(checkpoint), "config.py")
    assert os.path.isfile(config_file), "file config.py not found!"

    # (2.2) load the network configuration (import config.py)
    cfg_module = load_module_from_file(config_file)
    cfg = cfg_module.cfg

    # (2.3) build the model and load the weights
    cfg.update({"checkpoint": checkpoint})
    sam = build_sam(cfg)  # "ResTV2"
    if torch.cuda.is_available():
        sam.cuda()
    predictor = SamPredictor(sam, cfg.dataset)
    print("predictor device: ", predictor.device)

    # ============= (3) testing =============
    # (3.1) image patch
    center_v = [68, 134, 92]
    center_w = voxel_to_world(image, center_v)
    # after resampling to RAI: x_axis=(1,0,0), y_axis=(0,1,0), z_axis=(0,0,1)
    x_axis, y_axis, z_axis = np.array(image.GetDirection()).reshape(3, 3).transpose()
    image_patch = crop_roi_with_center(image, center_w, image.GetSpacing(),
                                       x_axis, y_axis, z_axis,
                                       [64, 64, 64], "linear", -1024)

    # (3.2) prediction
    predictor.set_image(image_patch)
    point_coords = np.array([[33, 39, 32]])  # shape=(1, 3)
    point_labels = np.array([1])  # shape=(1,)
    mask_predict, probability, result3 = predictor.predict(point_coords=point_coords,
                                                           point_labels=point_labels)

    # (3.3) predicted mask
    result_mask = sitk.GetArrayFromImage(mask_predict)  # uint8, (z, y, x) order
    # result_mask = np.transpose(result_mask, (2, 1, 0))
    patch_img = sitk.GetArrayFromImage(image_patch)  # float32

    result_mask.tofile("./resultMask.raw")
    patch_img.tofile("./patchImg.raw")

    # debug breakpoint anchor
    d = 10
```
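
As a side note, instead of raw dumps the result can also be written as NIfTI, so the mask and the patch can be overlaid in a viewer such as ITK-SNAP (a small sketch using only standard SimpleITK calls; the output paths are placeholders):

```python
# save the predicted mask and the patch with their geometry preserved
sitk.WriteImage(mask_predict, "./resultMask.nii.gz")
sitk.WriteImage(image_patch, "./patchImg.nii.gz")
```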
jianjun0407 commented 1 month ago

Happily, the third problem was solved by setting multimask_output=False!
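
For reference, this is presumably the predict call from step (3.2) above with the flag added, the rest of the script unchanged:

```python
# single-mask output fixes the abnormal prediction
mask_predict, probability, result3 = predictor.predict(point_coords=point_coords,
                                                       point_labels=point_labels,
                                                       multimask_output=False)
```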

henguo commented 1 month ago

It's great to hear that you have tested the model successfully. Regarding the first question, we currently have no plans to release a larger model. As for the second question, our interaction is driven by clicks rather than the sliding-window approach used in fully automatic segmentation methods; for large organs, providing more points typically resolves the issue. On the algorithm side, we have designed a cross-patch prompting module to address this, but to keep the interaction logic simple and clear we have not integrated it into the interactive tool, and manual clicking remains the primary interaction method.
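
For instance, a multi-click prompt could look like the sketch below. The coordinates are illustrative placeholders, `predictor.predict` is used with the same point_coords/point_labels convention as in the code above, and I am assuming the usual SAM convention of 1 = foreground click, 0 = background click:

```python
import numpy as np

# several positive clicks spread across a large organ, plus one
# negative click outside it; all coordinates are placeholders
point_coords = np.array([[20, 32, 32],
                         [32, 32, 32],
                         [44, 32, 32],   # inside the organ
                         [60, 10, 10]])  # outside the organ
point_labels = np.array([1, 1, 1, 0])    # 1 = foreground, 0 = background

mask, prob, low_res = predictor.predict(point_coords=point_coords,
                                        point_labels=point_labels,
                                        multimask_output=False)
```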