StanfordMIMI / skm-tea

Repository for the Stanford Knee MRI Multi-Task Evaluation (SKM-TEA) Dataset
MIT License
77 stars, 15 forks

How to predict the segmentation result of a volume #10

Closed: XiaotianJia closed this issue 2 years ago

XiaotianJia commented 2 years ago

Hi, I am very interested in segmentation on SKM-TEA, but I have two questions. After loading DICOM data and the weights of a model, what preprocessing should be applied to the data before it is sent to the network? And how should the data be passed to the network to predict its segmentation mask? Is there an example that shows the process?

ad12 commented 2 years ago

Hi @XiaotianJia, thanks for this question. We recommend a volumetric zero-mean, unit-standard-deviation normalization, i.e. the mean and standard deviation are computed over the entire volume. We also recommend using the DOSMA library to load DICOMs and orient scans.
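A minimal NumPy sketch of what "volumetric" normalization means here, assuming the scan is already a 3D array: the mean and standard deviation are computed once over every voxel, not slice by slice. `normalize_volume` and its `eps` guard are hypothetical names for illustration, not part of skm-tea or DOSMA. Note that NumPy's `std` defaults to the population estimator (`ddof=0`), whereas `torch.Tensor.std` defaults to the unbiased one; for a full knee volume the difference is negligible.

```python
import numpy as np


def normalize_volume(volume: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Zero-mean, unit-standard-deviation normalization over the whole volume.

    Statistics are pooled across all voxels (every slice at once), so the
    relative intensity between slices is preserved.
    """
    volume = volume.astype(np.float32)
    mean = volume.mean()
    std = volume.std()
    # eps is a hypothetical safeguard against division by zero for constant inputs.
    return (volume - mean) / (std + eps)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vol = rng.normal(loc=300.0, scale=50.0, size=(4, 16, 16))
    out = normalize_volume(vol)
    print(float(out.mean()), float(out.std()))  # approximately 0 and 1
```

Because the statistics are pooled, a slice that is darker than the rest of the volume stays darker after normalization; per-slice normalization would erase that difference.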

I've included some example code below for models that take a single-channel input (e.g. rss, E1, E2). The code assumes you are using the root-sum-of-squares (rss) of the two echoes; however, the volume can be replaced with whichever echo or echoes are the target for segmentation.
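Before the full example, here is a short sketch of the root-sum-of-squares combination in isolation, written for plain NumPy arrays (the name `root_sum_of_squares` is illustrative, not a skm-tea function). It accepts any number of co-registered echoes; with a single echo it simply returns the voxel-wise magnitude, so for an E1- or E2-only model you could equally skip this step.

```python
import numpy as np


def root_sum_of_squares(*echoes: np.ndarray) -> np.ndarray:
    """Combine co-registered echo volumes voxel-wise: sqrt(sum_i echo_i ** 2)."""
    if not echoes:
        raise ValueError("at least one echo is required")
    # Stack along a new leading axis so the sum runs over echoes, not voxels.
    stacked = np.stack([np.asarray(e, dtype=np.float32) for e in echoes], axis=0)
    return np.sqrt(np.sum(stacked ** 2, axis=0))


if __name__ == "__main__":
    echo1 = np.full((2, 3, 3), 3.0)
    echo2 = np.full((2, 3, 3), 4.0)
    rss = root_sum_of_squares(echo1, echo2)
    print(rss.shape, float(rss[0, 0, 0]))  # (2, 3, 3) 5.0
```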

```python
import numpy as np
import torch

import dosma as dm
import meddlr.ops as O
import skm_tea as st

# Read the qDESS DICOMs, grouped into one volume per echo.
echo1, echo2 = tuple(dm.read("/path/to/qdess/dicoms", group_by="EchoNumbers"))
echo1, echo2 = echo1.astype(np.float32), echo2.astype(np.float32)

# If you are using the rss, compute it here.
rss = np.sqrt(echo1 ** 2 + echo2 ** 2)

# Put the sagittal (left-right) dimension first.
rss = rss.reformat(("LR", "SI", "AP"))

# Convert to a torch tensor and add a channel dimension.
# The sagittal dimension is used as the batch dimension.
# Note: the array is copied because reorientation can produce negative strides,
# which torch.from_numpy does not support.
rss_tensor = torch.from_numpy(rss.A.copy()).unsqueeze(1)  # Shape: (Z, 1, X, Y)

# Volumetric zero-mean, unit-standard-deviation normalization.
rss_tensor = (rss_tensor - rss_tensor.mean()) / rss_tensor.std()

# Get the U-Net RSS model (feel free to replace it with the model of your choice).
model = st.get_model_from_zoo(
    "download://https://drive.google.com/file/d/1HLOXapcKSqfVr_2mSxR_OlgUXyarkkNd/view?usp=sharing",
    "download://https://drive.google.com/file/d/1QAcdPSh1VQ967Q1loiiZRFo0JBZb5yIs/view?usp=sharing",
).eval()

# Run the model.
with torch.no_grad():
    output = model({"image": rss_tensor})

# Convert the logits to per-class probabilities with a sigmoid.
pred = O.logits_to_prob(output["sem_seg_logits"], "sigmoid")  # Shape: (Z, C, X, Y)

# (Optional) Put them back into a MedicalVolume (converted to a NumPy array first).
seg_volume = dm.MedicalVolume(pred.permute((0, 2, 3, 1)).numpy(), affine=rss.affine)  # Shape: (Z, X, Y, C)
```
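The array bookkeeping in the pipeline above can be checked without a model. This NumPy-only sketch (both function names are hypothetical) mimics the negative strides a reorientation can leave behind, adds the channel axis, and then turns sigmoid probabilities into binary masks in (Z, X, Y, C) order. The 0.5 threshold is an assumed default, not a value given in this thread; thresholding sigmoid outputs at 0.5 is equivalent to thresholding the logits at 0.

```python
import numpy as np


def to_batch_of_slices(volume: np.ndarray) -> np.ndarray:
    """(Z, X, Y) -> (Z, 1, X, Y): sagittal slices become the batch axis."""
    # .copy() returns a C-ordered array with positive strides, which is what
    # torch.from_numpy needs when reorientation left negative strides behind.
    return volume.copy()[:, np.newaxis, :, :]


def probabilities_to_masks(prob: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """(Z, C, X, Y) sigmoid probabilities -> (Z, X, Y, C) uint8 binary masks."""
    masks = (prob >= threshold).astype(np.uint8)
    # Same axis reordering as pred.permute((0, 2, 3, 1)) in the snippet above.
    return np.transpose(masks, (0, 2, 3, 1))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Stand-in for a reoriented scan: flipping an axis yields a negative stride.
    volume = np.flip(rng.random((5, 8, 8), dtype=np.float32), axis=0)
    print(volume.strides[0] < 0)  # True
    batch = to_batch_of_slices(volume)
    print(batch.shape)  # (5, 1, 8, 8)
    # Hypothetical model output: random logits for 4 channels (an arbitrary
    # number, not the SKM-TEA class count), then a sigmoid.
    logits = rng.normal(size=(5, 4, 8, 8))
    prob = 1.0 / (1.0 + np.exp(-logits))
    masks = probabilities_to_masks(prob)
    print(masks.shape)  # (5, 8, 8, 4)
```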

If you would like to use a pretrained model out of the box without having to manage normalization, orientation, etc., the DOSMA library has a version of the model trained in TensorFlow/Keras that should be plug-and-play. This tutorial highlights how to use this model.

XiaotianJia commented 2 years ago

Thank you very much for your detailed answer; it is very helpful.