wasserth / TotalSegmentator

Tool for robust segmentation of >100 important anatomical structures in CT and MR images
Apache License 2.0

Model weights on Hugging Face #201

Open · katielink opened this issue 11 months ago

katielink commented 11 months ago

Hi! This is really amazing work! It would be awesome to have the model weights shared on Hugging Face. Here's some more information on how to do so: https://huggingface.co/docs/hub/models-uploading. Happy to help with any questions!
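For example, with the huggingface_hub client it could look roughly like this (a minimal sketch; the repo id and local folder path below are placeholders, not real locations):

# Sketch of mirroring the weight files to the Hub; paths are placeholders.
from huggingface_hub import HfApi

api = HfApi()  # assumes you are logged in via `huggingface-cli login`
api.create_repo("your-org/TotalSegmentator-weights", exist_ok=True)
api.upload_folder(
    folder_path="path/to/nnunet/results",  # placeholder local path
    repo_id="your-org/TotalSegmentator-weights",
)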

wasserth commented 11 months ago

Thanks for the suggestion. Hugging Face is doing a great job hosting machine learning models. However, I think this model is a bit different from your typical PyTorch model: just hosting the weights on Hugging Face will not really help any of our users.

katielink commented 11 months ago

Thanks for your response! I'm curious to hear why this model is different from your typical PyTorch model. I get asked quite frequently about medical image segmentation models, and I think even having this model searchable within the Hub would be useful for others in the ML community who might not yet be aware of this awesome work. If it's okay with you, I'd be happy to help put the weights for the TotalSegmentator models on HF (of course, giving you full and proper attribution, linking back to this repository, including the appropriate license, etc.). Let me know what you think.

pangyuteng commented 9 months ago

@katielink To my knowledge, inference for all organs requires running several models; further, a bit of post-processing is involved (see https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/nnunet.py#L331). A rough sketch of the idea is below.
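Conceptually, something like this (a sketch only, not the actual TotalSegmentator code; the function and parameter names are made up):

# Conceptual sketch: the full anatomy task is split across several nnU-Net
# models whose outputs are merged back into a single label map.
import numpy as np

def predict_all_organs(image, models, label_offsets):
    """Run each sub-model and merge its labels into one segmentation."""
    combined = np.zeros(image.shape, dtype=np.uint16)
    for model, offset in zip(models, label_offsets):
        part = model.predict(image)             # per-model label map
        mask = part > 0
        combined[mask] = part[mask] + offset    # remap to global label ids
    # the real pipeline then applies postprocessing (resampling, cleanup, ...)
    return combined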

On top of that, you'll need a decent GPU, or else inference will take a while (I'm guessing 20 to 40 min, depending on your CPU spec). So if people go the Hugging Face inference endpoint route, they will need to bump up the instance spec 💸. Most people/researchers using this are on a tight budget and/or likely already have on-prem GPUs/HPCs.

That said, I gave it a quick shot (see the sample Hugging Face Space / Gradio implementation below), since I was curious whether the free hardware could run TotalSegmentator without error. I think you could also convert it to a "huggingface model repo" with a custom handler; a rough sketch follows, then the Space files. One missing piece is getting the weights cached in the deployed Docker container / model repo. Also, I had to add the "--fast" flag, or else it times out at the 30 min limit when deployed to huggingface.co on the free instance (2 CPUs, 16 GB RAM).
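For the model-repo route, a custom handler could look roughly like this (a sketch assuming the Inference Endpoints EndpointHandler convention; the payload format, a base64-encoded NIfTI under "inputs", is an illustrative choice, not an official API):

# handler.py -- sketch of a custom Inference Endpoints handler
import base64
import os
import subprocess
import tempfile
from typing import Any, Dict

class EndpointHandler:
    def __init__(self, path: str = ""):
        # download the weights once at container start so requests reuse them
        subprocess.run("totalseg_download_weights -t total", shell=True, check=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        nifti_bytes = base64.b64decode(data["inputs"])
        with tempfile.TemporaryDirectory() as tmpdir:
            in_path = os.path.join(tmpdir, "image.nii.gz")
            out_dir = os.path.join(tmpdir, "segmentations")
            with open(in_path, "wb") as f:
                f.write(nifti_bytes)
            subprocess.run(f"TotalSegmentator -i {in_path} -o {out_dir} --fast",
                           shell=True, check=True)
            return {"masks": sorted(os.listdir(out_dir))}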

requirements.txt (TODO: pin the package version):

totalsegmentator

app.py:

import os
import subprocess
import tempfile

import gradio as gr
import totalsegmentator  # imported only to fail fast if the package is missing

# Download the pretrained weights once per process. Note that the env var
# only lives inside the running container, so a restarted instance will
# download the weights again.
if os.environ.get("WEIGHTS_CACHED") != "TRUE":
    print("downloading pretrained weights...")
    subprocess.call("totalseg_download_weights -t total", shell=True)
    os.environ["WEIGHTS_CACHED"] = "TRUE"

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLE_NIFTI_PATH = os.path.join(THIS_DIR, "files", "sample-image.nii.gz")

def myfunc(file_obj):
    """Segment the uploaded NIfTI file and return the list of mask files."""
    file_list = []
    input_file_path = file_obj.name
    with tempfile.TemporaryDirectory() as tmpdir:
        output_folder_path = os.path.join(tmpdir, "segmentations")
        # "--fast" uses the lower-resolution model so the free CPU instance
        # finishes before the 30 min timeout
        cmd_str = f"TotalSegmentator -i {input_file_path} -o {output_folder_path} --fast"
        subprocess.call(cmd_str, shell=True)
        if os.path.exists(output_folder_path):
            file_list = os.listdir(output_folder_path)
        return {"status": file_list}  # TODO: maybe add papaya js element to view nifti

if __name__ == "__main__":
    demo = gr.Interface(myfunc, ["file"], "json",
                        examples=[EXAMPLE_NIFTI_PATH], cache_examples=True)
    demo.queue().launch(debug=True, show_api=True, share=False)
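To try the same script locally (assuming totalsegmentator and gradio are installed), running python app.py should bring the demo up at http://127.0.0.1:7860.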

[screenshot: the Gradio demo running in the Hugging Face Space]