Project-MONAI / tutorials

MONAI Tutorials
https://monai.io/started.html
Apache License 2.0

Tutorial: Incorporating ONNX Support into Brain Tumor Segmentation #1696

Closed: ctongh closed this pull request 3 months ago

ctongh commented 4 months ago

Fixes #1681

Description

Issue #1681 requests a tutorial, based on brats_segmentation_3d, that adds a PyTorch-to-ONNX conversion step and compares the differences between the two models.

Checks

review-notebook-app[bot] commented 4 months ago

Check out this pull request on ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.

ctongh commented 4 months ago

@KumoLiu could you assist with a review? Please let me know if any additions or modifications need to be made to the tutorial. Thank you!

KumoLiu commented 4 months ago

Hi @ctongh, sorry for the late feedback, and thanks for your contribution. Given the similarities with 3d_image_segmentation, would you consider integrating this with the original one and adding more description in the README: https://github.com/Project-MONAI/tutorials/blob/main/README.md#brats_segmentation_3d? What do you think?

cc @ericspod @Nic-Ma

ctongh commented 4 months ago

Hi @KumoLiu, thanks for the quick reply!

Originally, I referenced the spleen_segmentation_3d and spleen_segmentation_3d_lightning tutorials, which have a similar extension relationship to each other. Following that pattern, I added a tutorial named brats_segmentation_3d_onnx; besides the additional parts, I made some adjustments to the original code to be compatible with onnxruntime, as follows:

# define inference method
# we don't use sliding window inference here, to avoid sizing errors
# when converting the PyTorch model to an ONNX model
simple_inferer = SimpleInferer()  # from monai.inferers

def inference(input):
    def _compute(input):
-       return sliding_window_inference(
-           inputs=input,
-           roi_size=(240, 240, 160),
-           sw_batch_size=1,
-           predictor=model,
-           overlap=0.5,
-       )
+       return simple_inferer(input, model)

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)

However, in the MONAI tutorials repository I found that sliding_window_inference is still the more common approach. So if we plan to integrate brats_segmentation_3d_onnx into the original brats_segmentation_3d, perhaps we can retain the original sliding_window_inference code snippet (commented out).

Additionally, regarding the README: other tutorial READMEs seem to provide only a brief introduction, so if the above integration is not an issue, I will promptly upload the updated README.

Note: We actually trained for 350 epochs, and the results show that the performance of the two different inference methods mentioned above is the same. (We will be happy to provide the parameter files if needed.)

KumoLiu commented 4 months ago

> So if we plan to integrate brats_segmentation_3d_onnx into the original brats_segmentation_3d, perhaps we can retain the original sliding_window_inference code snippet (commented out).

Agree.

> Additionally, regarding the README: other tutorial READMEs seem to provide only a brief introduction, so if the above integration is not an issue, I will promptly upload the updated README.

Yes, just add maybe one or two sentences mentioning that we have included ONNX support in that tutorial.
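
For example, something along these lines (just an illustration of the kind of sentence meant, not required wording):

"This tutorial also demonstrates how to export the trained model to ONNX and run inference with ONNX Runtime."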

Thanks!

ericspod commented 4 months ago

> Replaced the sliding_window_inference commonly used in medical imaging with a fixed-size SimpleInferer.

Hi @ctongh, thanks for the contribution. I think this works in this case because the sliding window is very close in size to the input, so only a small amount of the image edge is lost when using direct inference rather than sliding-window inference. If there were features at the edges this wouldn't work so well, so it isn't a general solution.

This notebook duplicates a lot of content from the original notebook. We had discussed it and felt it might be better either to put the ONNX-related cells at the end of the original notebook, or to remove all the duplicate cells from this one and add a note saying to run the BRATS notebook first to generate the checkpoint needed to make the ONNX model. This would reduce the amount of code being added to the repo. How do you feel about either of those changes?
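
For reference, the checkpoint-to-ONNX step being discussed could look roughly like this (a sketch, not code from the PR; it assumes model is the trained network from the BRATS notebook, root_dir is the tutorial's working directory, and the checkpoint file name follows the tutorial's best_metric_model.pth convention):

import os
import torch

# load the checkpoint produced by running the BRATS notebook first
model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))
model.eval()

# a dummy 4-channel BraTS-sized input fixes the exported graph's input shape
dummy_input = torch.randn(1, 4, 240, 240, 160, device=next(model.parameters()).device)

torch.onnx.export(model, dummy_input, os.path.join(root_dir, "best_metric_model.onnx"))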

ctongh commented 4 months ago

Hi @ericspod, thank you for your advice. We discussed it and agreed that it would be better to merge the ONNX support into the original brats_segmentation_3d file. There are two options for replacing sliding_window_inference with SimpleInferer:

  1. Create a helper function so that users can decide whether to use sliding_window_inference or SimpleInferer (a rough sketch follows below).
  2. Comment out the sliding_window_inference part, and mention in the description that sliding_window_inference is the more common approach in most cases (and explain why).

Considering that tutorials usually aim to give users a simple, easy-to-follow experience, we think option 2 is the better approach (it avoids extra operations). What do you think?
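
For illustration, option 1 could look roughly like the following sketch (reusing model, VAL_AMP, and torch from the original notebook's context; this is not code from the PR):

from monai.inferers import SimpleInferer, sliding_window_inference

simple_inferer = SimpleInferer()

def inference(input, use_sliding_window=True):
    # let the caller pick between the two inference strategies
    def _compute(input):
        if use_sliding_window:
            return sliding_window_inference(
                inputs=input,
                roi_size=(240, 240, 160),
                sw_batch_size=1,
                predictor=model,
                overlap=0.5,
            )
        return simple_inferer(input, model)

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    return _compute(input)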

ericspod commented 4 months ago

Hi @ctongh, I think it's possible to use sliding_window_inference with ONNX by wrapping the ONNX call in a function which is passed instead of the network itself. This would look something like this with a function onnx_infer to act as the predictor for the sliding_window_inference call:

import os

import onnxruntime
import torch
from monai.inferers import sliding_window_inference

onnx_model_path = os.path.join(root_dir, "best_metric_model.onnx")
ort_session = onnxruntime.InferenceSession(onnx_model_path)

# wrap the call in something that will convert to and from formats as needed:
def onnx_infer(inputs):
    ort_inputs = {ort_session.get_inputs()[0].name: inputs.cpu().numpy()}
    ort_outs = ort_session.run(None, ort_inputs)
    return torch.Tensor(ort_outs[0]).to(inputs.device)

...

# use this function with sliding inference like so:
prediction = sliding_window_inference(
    inputs=input,
    roi_size=(240, 240, 160),
    sw_batch_size=1,
    predictor=onnx_infer,
    overlap=0.5,
)

With this I think you can adapt the inference cell from the original notebook to use sliding_window_inference in this way rather than the inference mentioned there. This and the other ONNX-related code could go at the end of the existing notebook, which would replicate all the functionality with an ONNX model and demonstrate how to make it work with existing utilities. Let me know if that makes sense and will work (I admit I haven't tried it).

ctongh commented 3 months ago

Hi @ericspod, thank you for your advice. Since the validation data size is fixed at [240, 240, 155] and sliding_window_inference's roi_size is set to [240, 240, 160], this configuration covers the entire validation volume in a single window (i.e., the inference involves no actual sliding). This behavior is very similar to SimpleInferer, where both methods run inference on fixed-size images. Packaging it as onnx_infer sounds like a good approach! I will update the research results in a few days.
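
For the comparison itself, a minimal agreement check between the two paths could look like this sketch (val_input is a hypothetical preprocessed validation volume; inference and onnx_infer are as defined earlier in the thread):

import numpy as np
import torch
from monai.inferers import sliding_window_inference

with torch.no_grad():
    torch_out = inference(val_input)  # PyTorch path from the original notebook
    onnx_out = sliding_window_inference(
        inputs=val_input,
        roi_size=(240, 240, 160),
        sw_batch_size=1,
        predictor=onnx_infer,
        overlap=0.5,
    )

# the two outputs should agree up to small numeric differences
np.testing.assert_allclose(torch_out.cpu().numpy(), onnx_out.cpu().numpy(), rtol=1e-3, atol=1e-3)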

ctongh commented 3 months ago

Hi @ericspod, implementing this method has yielded excellent results, and the code is now much more concise. Could you please assist with a review? If any modifications are needed, feel free to discuss them with me. Thank you!

ctongh commented 3 months ago

Hi @ericspod, I found that the error occurred in cell 29. This cell runs inference with the PyTorch model (which hasn't been changed). However, it failed the check because of a timeout; I believe the longer inference time is due to the size of the dataset. Please take a look at the remaining issues. Thank you!

ericspod commented 3 months ago

> Hi @ericspod, I found that the error occurred in cell 29. This cell runs inference with the PyTorch model (which hasn't been changed). However, it failed the check because of a timeout; I believe the longer inference time is due to the size of the dataset. Please take a look at the remaining issues. Thank you!

Downloading BraTS locally is super slow for me at the moment, and that could be part of it. I'm trying to rerun the job and we'll see if it can get through; it shouldn't take 6 hours even with this download issue. If it still fails, you can test things locally with ./runner.sh -t 3d_segmentation/brats_segmentation_3d.ipynb.

ctongh commented 3 months ago

Hi @ericspod @KumoLiu, I changed the verbose parameter to False in torch.onnx.export and found that it significantly reduced the runtime of this cell in my tests.
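
Concretely, the change is just the verbose flag on the export call (a sketch; model and dummy_input as in the notebook's export cell):

# verbose=True makes torch.onnx.export log the whole exported graph,
# which is very slow for a large 3D model; the default False avoids that
torch.onnx.export(model, dummy_input, "best_metric_model.onnx", verbose=False)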

When I used runner.sh to test, I noticed that even the original notebook takes a very long time to run. This is likely due to the large dataset and the big sliding window (I also tested other notebooks with similar characteristics, such as spleen_segmentation_3d, and encountered the same issue). However, this issue doesn't seem to be related to the ONNX support, so we could consider merging this change first; I will open a separate issue to track the runtime problem. Please let me know your thoughts. Thank you!