Teichlab / bin2cell

Join subcellular Visium HD bins into cells
MIT License

how to load local model for b2c.stardist function #28

Open dfgao opened 5 days ago

dfgao commented 5 days ago

hi bin2cell team,

My server can't access GitHub to download python_2D_versatile_he.zip. I downloaded the file on another computer and can load the model in Python, but b2c.stardist still tries to download it from GitHub. Here's my code.

import os
model_dir = "/root/.keras/models/StarDist2D/2D_versatile_he"
files = os.listdir(model_dir)
print(files)

['thresholds.json', 'weights_best.h5', 'config.json']

from stardist.models import StarDist2D
from pathlib import Path
model_dir = Path("/root/.keras/models/StarDist2D/2D_versatile_he")
if not model_dir.exists():
    raise FileNotFoundError(f"Model path does not exist: {model_dir}")
model = StarDist2D(None, name=str(model_dir))

Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.

b2c.stardist(image_path="stardist/he.jpg",
             labels_npz_path="stardist/he.npz",
             stardist_model="2D_versatile_he",
             prob_thresh=0.01
            )

Exception: URL fetch failure on https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_he.zip: None -- [Errno 101] Network is unreachable

Do you have any better suggestions?

ktpolanski commented 5 days ago

Interesting, you're not the first person to run into this. The other one came to me via email so there's no track record of this on GitHub.

Here's a workaround you could try using, which accepts the model as a function argument. As a result, since we no longer have the name of the model, you need to pass model_axes to go with it - for H&E that's "YXC", for IF/GEX that's "YX".

import numpy as np
import scipy.sparse
import bin2cell as b2c

def stardist_custom(image_path, labels_npz_path, model, model_axes, block_size=4096, min_overlap=128, context=128, **kwargs):
    #using stardist models requires tensorflow, avoid global import
    from stardist.models import StarDist2D
    #load and percentile normalize image, following stardist demo protocol
    #turn it to np.float16 pre normalisation to keep RAM footprint minimal
    #determine whether to greyscale or not based on the length of model_axes
    #YX is length 2, YXC is length 3
    img = b2c.load_image(image_path, gray=(len(model_axes) == 2), dtype=np.float16)
    img = b2c.normalize(img)
    #we passed a custom model as model already, so there is no from_pretrained() call here
    #we also passed model_axes to go with it so we're good
    #run predict_instances_big() to perform automated tiling of the input
    #this is less parameterised than predict_instances, needed to pass axes too
    #pass any other **kwargs to the thing, passing them on internally
    #in practice this is going to be prob_thresh
    labels, _ = model.predict_instances_big(img, axes=model_axes, 
                                            block_size=block_size, 
                                            min_overlap=min_overlap, 
                                            context=context, 
                                            **kwargs
                                           )
    #store resulting labels as sparse matrix NPZ - super efficient space wise
    labels_sparse = scipy.sparse.csr_matrix(labels)
    scipy.sparse.save_npz(labels_npz_path, labels_sparse)
    print("Found "+str(len(np.unique(labels_sparse.data)))+" objects")

Let me know how you get on with it.

dfgao commented 5 days ago

Wow, it works. Here's my code:

Add the following code to the bin2cell.py file and comment out the def stardist function:

def stardist_custom(image_path, labels_npz_path, model, model_axes, block_size=4096, min_overlap=128, context=128, **kwargs):
    #using stardist models requires tensorflow, avoid global import
    from stardist.models import StarDist2D
    #load and percentile normalize image, following stardist demo protocol
    #turn it to np.float16 pre normalisation to keep RAM footprint minimal
    #determine whether to greyscale or not based on the length of model_axes
    #YX is length 2, YXC is length 3
    img = load_image(image_path, gray=(len(model_axes) == 2), dtype=np.float16)
    img = normalize(img)
    #we passed a custom model as model already
    #we also passed model_axes to go with it so we're good
#    model = StarDist2D.from_pretrained(stardist_model)
    #run predict_instances_big() to perform automated tiling of the input
    #this is less parameterised than predict_instances, needed to pass axes too
    #pass any other **kwargs to the thing, passing them on internally
    #in practice this is going to be prob_thresh
    labels, _ = model.predict_instances_big(img, axes=model_axes, 
                                            block_size=block_size, 
                                            min_overlap=min_overlap, 
                                            context=context, 
                                            **kwargs
                                           )
    #store resulting labels as sparse matrix NPZ - super efficient space wise
    labels_sparse = scipy.sparse.csr_matrix(labels)
    scipy.sparse.save_npz(labels_npz_path, labels_sparse)
    print("Found "+str(len(np.unique(labels_sparse.data)))+" objects")

Then, run the following commands in the Jupyter terminal:

import bin2cell as b2c
from stardist.models import StarDist2D
from pathlib import Path

# python_2D_versatile_he.zip was downloaded separately and unzipped into this folder
model_dir = Path("/root/.keras/models/StarDist2D/2D_versatile_he")

if not model_dir.exists():
    raise FileNotFoundError(f"Model path does not exist: {model_dir}")

model = StarDist2D(None, name=str(model_dir))
print("Model loaded successfully!")

b2c.stardist_custom(
    image_path="stardist/he.jpg", 
    labels_npz_path="stardist/he.npz",
    model=model,  
    model_axes="YXC",
    prob_thresh=0.01
)

output:

effective: block_size=(4096, 4096, 3), min_overlap=(128, 128, 0), context=(128, 128, 0)
functional.py (225): The structure of `inputs` doesn't match the expected structure: ['input']. Received: the structure of inputs=*
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [02:28<00:00,  9.30s/it]

Thank you for your help 😀

ktpolanski commented 5 days ago

Awesome. I'll keep this issue open to remind myself to figure out an elegant way to support custom StarDist models, as this sorts out corner cases like yours and technically gives users extra possibilities as well.

dfgao commented 5 days ago

Thank you so much! I have another question: My area contains two non-overlapping tissues, one on the left and one on the right. Do you have any suggestions on how to separately generate cdata for each tissue?

ktpolanski commented 5 days ago

This shouldn't be a problem in bin2cell itself, as non-overlapping implies that the segmented objects should be quite far apart.

Do a simple scatterplot of cdata.obsm["spatial"] and see if you can define a simple threshold to split the objects into two groups. I'd expect that to be the most likely way to get it done.
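
A minimal sketch of that thresholding idea, in case it helps. It assumes the two tissues are separated along x, that column 0 of cdata.obsm["spatial"] holds the x coordinate, and it uses a synthetic array in place of the real coordinates. The helper names largest_gap_threshold and split_left_right are hypothetical and not part of bin2cell.

```python
import numpy as np

def largest_gap_threshold(values):
    """Return the midpoint of the widest gap between the sorted 1D values."""
    v = np.sort(np.asarray(values, dtype=float))
    gaps = np.diff(v)
    i = int(np.argmax(gaps))
    return (v[i] + v[i + 1]) / 2.0

def split_left_right(coords, threshold=None):
    """Return boolean masks (left, right) and the x threshold that was used.

    coords is an (n, 2) array; column 0 is assumed to be x.
    If no threshold is given, the widest gap in x is used as a guess.
    """
    coords = np.asarray(coords, dtype=float)
    x = coords[:, 0]
    if threshold is None:
        threshold = largest_gap_threshold(x)
    left = x < threshold
    return left, ~left, threshold

# Synthetic stand-in for cdata.obsm["spatial"]: two well separated blobs.
rng = np.random.default_rng(0)
demo = np.vstack([
    rng.uniform([0, 0], [1000, 2000], size=(50, 2)),
    rng.uniform([3000, 0], [4000, 2000], size=(60, 2)),
])
left, right, thr = split_left_right(demo)
print(int(left.sum()), int(right.sum()))

# With a real object this would look something like (untested assumption):
#   left, right, thr = split_left_right(cdata.obsm["spatial"])
#   cdata_left = cdata[left].copy()
#   cdata_right = cdata[right].copy()
```

It is worth looking at the scatterplot first and passing an explicit threshold if the automatic guess lands in the wrong place, for example when one tissue has a large internal hole.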

dfgao commented 5 days ago

Thank you again for your suggestion. I will try it out soon. 👍