constantinpape opened this issue 9 months ago
Great!!! Thanks a lot :)
I will try to make it "napari compatible" using `viewer.layers` as well :)
Thanks!!!
My attempt on a 2D image:
```python
import numpy as np
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_inference
from skimage.measure import regionprops

# `viewer` is the running napari viewer (e.g. from napari's built-in console):
# layer 0 holds the image, layer 1 the initial segmentation.
image = viewer.layers[0]
initial_segmentation = viewer.layers[1]

labels = initial_segmentation.data
labeled_image = np.array(labels)

image_data = image.data  # read the image layer (layers[0]), not the segmentation layer
image = np.array(image_data)

props = regionprops(labeled_image)
# The coordinates of the centroids need to be reversed to match the convention of SAM.
points = np.array([prop.centroid for prop in props])[:, ::-1]
point_labels = np.ones(len(points), dtype="int")  # <- All prompts are positive.

predictor = get_sam_model(model_type="vit_b_lm")  # <- you can control which model is used with the model type argument.
# See the function signature of get_sam_model for details.

refined_segmentation = batched_inference(
    predictor, image,
    batch_size=32,  # This controls how many points are processed at once, lower it if you get memory issues
    points=points,
    point_labels=point_labels,
    return_instance_segmentation=True
)
```
```
RuntimeError Traceback (most recent call last)
Cell In[3], line 29
26 predictor = get_sam_model(model_type="vit_b_lm") # <- you can control which model is used with the model type argument.
27 # See the function signature of get_sam_model for details.
---> 29 refined_segmentation = batched_inference(
30 predictor, image,
31 batch_size=32, # This controls how many points are processed at once, lower it if you get memory issues
32 points=points,
33 point_labels=point_labels,
34 return_instance_segmentation=True
35 )
File ~\AppData\Local\micro_sam\Lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~\AppData\Local\micro_sam\Lib\site-packages\micro_sam\inference.py:115, in batched_inference(predictor, image, batch_size, boxes, points, point_labels, multimasking, embedding_path, return_instance_segmentation, segmentation_ids, reduce_multimasking)
112 batch_points = points[batch_start:batch_stop] if have_points else None
113 batch_labels = point_labels[batch_start:batch_stop] if have_points else None
--> 115 batch_masks, batch_ious, _ = predictor.predict_torch(
116 point_coords=batch_points, point_labels=batch_labels,
117 boxes=batch_boxes, multimask_output=multimasking
118 )
120 # If we expect to reduce the masks from multimasking and use multi-masking,
121 # then we need to select the most likely mask (according to the predicted IOU) here.
122 if reduce_multimasking and multimasking:
File ~\AppData\Local\micro_sam\Lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~\AppData\Local\micro_sam\Lib\site-packages\segment_anything\predictor.py:222, in SamPredictor.predict_torch(self, point_coords, point_labels, boxes, mask_input, multimask_output, return_logits)
219 points = None
221 # Embed prompts
--> 222 sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
223 points=points,
224 boxes=boxes,
225 masks=mask_input,
226 )
228 # Predict masks
229 low_res_masks, iou_predictions = self.model.mask_decoder(
230 image_embeddings=self.features,
231 image_pe=self.model.prompt_encoder.get_dense_pe(),
(...)
234 multimask_output=multimask_output,
235 )
File ~\AppData\Local\micro_sam\Lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~\AppData\Local\micro_sam\Lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~\AppData\Local\micro_sam\Lib\site-packages\segment_anything\modeling\prompt_encoder.py:155, in PromptEncoder.forward(self, points, boxes, masks)
153 if points is not None:
154 coords, labels = points
--> 155 point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156 sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157 if boxes is not None:
File ~\AppData\Local\micro_sam\Lib\site-packages\segment_anything\modeling\prompt_encoder.py:84, in PromptEncoder._embed_points(self, points, labels, pad)
82 padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
---> 84 points = torch.cat([points, padding_point], dim=1)
85 labels = torch.cat([labels, padding_label], dim=1)
86 point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
RuntimeError: Tensors must have same number of dimensions: got 2 and 3
```
Not sure how to fix the dimension problem...?
Thanks, antho
It turns out an extra dimension needs to be added to the inputs of `batched_inference`:
```python
# We need to add an extra dimension to provide the correct input for batched_inference.
points = np.expand_dims(points, 1)
point_labels = np.expand_dims(point_labels, 1)
```
(I have updated the pseudo-code on top too, so you can see the full example there.)
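Why the extra axis is needed: in the traceback, SAM's prompt encoder concatenates the point coordinates with a padding point of shape `(N, 1, 2)`, so the coordinates must already be 3D, with one row of points per object. A minimal NumPy-only sketch of the shape change, using three made-up toy points in place of real centroids:

```python
import numpy as np

# Toy prompts: three centroids in (x, y) order, shape (N, 2).
points = np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]])
point_labels = np.ones(len(points), dtype="int")  # shape (N,)

# Add a "points per object" axis: each object gets exactly one point prompt.
points = np.expand_dims(points, 1)              # (N, 2) -> (N, 1, 2)
point_labels = np.expand_dims(point_labels, 1)  # (N,)   -> (N, 1)

print(points.shape, point_labels.shape)  # (3, 1, 2) (3, 1)
```

The middle axis is the number of points per object; with one centroid per object it is simply 1.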
Thanks a lot!
There is a `shape` argument missing in the call to `mask_data_to_segmentation`; I will work on fixing it. Almost there...!
Thanks a lot, antho
OK, that seems to work (tested roughly!).
I also added this to the inference.py file (the `shape` argument was missing):
```python
if return_instance_segmentation:
    masks = mask_data_to_segmentation(masks, with_background=False, min_object_size=0, shape=image_shape)
```
```python
import numpy as np
import napari
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_inference
from skimage.measure import regionprops

# `viewer` is the running napari viewer: layer 0 holds the image, layer 1 the initial segmentation.
image = viewer.layers[0]
initial_segmentation = viewer.layers[1]

labels = initial_segmentation.data
labeled_image = np.array(labels)

image_data = image.data  # read the image layer (layers[0]), not the segmentation layer
image = np.array(image_data)

props = regionprops(labeled_image)
# The coordinates of the centroids need to be reversed to match the convention of SAM.
points = np.array([prop.centroid for prop in props])[:, ::-1]
point_labels = np.ones(len(points), dtype="int")  # <- All prompts are positive.

# Extra dimension expected by batched_inference: (N, 2) -> (N, 1, 2) and (N,) -> (N, 1).
points = np.expand_dims(points, 1)
point_labels = np.expand_dims(point_labels, 1)

predictor = get_sam_model(model_type="vit_b_lm")  # <- you can control which model is used with the model type argument.

refined_segmentation = batched_inference(
    predictor, image,
    batch_size=32,  # This controls how many points are processed at once, lower it if you get memory issues
    points=points,
    point_labels=point_labels,
    return_instance_segmentation=True
)

viewer.add_labels(refined_segmentation, name='Refined Segmentation Labels')
viewer.layers[-1].colormap = 'viridis'
viewer.layers[-1].opacity = 0.5
```
Will test more later on.
Thanks a lot!! :) antho
> OK, that seems to work (tested roughly!)

OK, great! Let me know how the quality looks. If there are any issues, this can probably be improved by adjusting some parameters.

> I also added in the inference.py file (shape attribute was missing):

This should not be necessary if you're working off the `dev` branch: https://github.com/computational-cell-analytics/micro-sam/blob/dev/micro_sam/instance_segmentation.py#L51 (`dev` contains the latest version, and we will merge it into the `master` branch soon).
But that is only a minor thing; just be aware that this might change soon on `master` too.
The results are slightly different, so I think I have to adjust the parameters as you suggested, but it works in principle :D.
I think `points_per_side` (default is 32) gives better results with 100 (more granular), but any customizable parameters will be useful :).
I can share the image / pre-SAM output if that helps?
The main idea is to refine the segmentation of the elongated cells (their sides are often not well segmented), but also to refine doublets and the fine segmentation in general. In addition, it would be really cool to be able to add points (automatically) for any cells missing from the pre-segmentation. That way it would do two things: refine the existing segmentation and add the missing cells (two birds, one stone!).
I am still on the master branch, but will switch to the dev one.
What about 3D (my main interest)? Thanks a lot :) antho
> The results are slightly different, hence I think I have to adjust the parameters as you suggested, but it works in principle :D.

That's great!

> I can share the image / pre-SAM output if that helps?

Yes, that would be quite helpful!

> In addition it would be really cool to be able to add points (automatically) for any missing cells from the pre-segmentation.

Do you have a good heuristic for adding points automatically?

> What about 3D (my main interest)?

I will follow up on that next week. (I am on a retreat this week, so my answers are a bit slower, but I will be working on this next week anyway and will share some code.)
Hi,
I sent an invite to your email to share the files. I included the original image, the masks from my custom 2D model and the refinements from micro_sam. As mentioned, the main improvement could be for the elongated nuclei; it would be great to refine these :).
For adding points: I was part of the last HTAN jamboree (https://github.com/NCI-HTAN-Jamborees/Improving-cell-segmentation-for-spatial-omics/tree/main), where we worked on similar approaches, and I know there are a few papers on this idea: using the specialized models as prompts (as we are doing now), but adding automatic grid points on top in case the specialized model missed some nuclei (the grid worked better with 100 points, if I remember correctly). I will dig into finding these papers later on.
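One possible heuristic for the "grid points on top" idea, sketched under assumptions: place a regular grid over the image (analogous to SAM's `points_per_side` grid) and keep only the grid points that fall on background of the pre-segmentation, so they can only prompt cells the specialized model missed. `background_grid_points` is a hypothetical helper, not a `micro_sam` function:

```python
import numpy as np


def background_grid_points(labels, points_per_side=32):
    """Hypothetical helper: grid points in (x, y) order that lie on background (label 0)."""
    height, width = labels.shape
    # Centres of a points_per_side x points_per_side grid of cells.
    ys = (np.arange(points_per_side) + 0.5) * height / points_per_side
    xs = (np.arange(points_per_side) + 0.5) * width / points_per_side
    grid_y, grid_x = np.meshgrid(ys, xs, indexing="ij")
    # Stack as (x, y), the coordinate order SAM expects for point prompts.
    grid = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)
    # Look up the label under each point and keep only uncovered points.
    rows = grid[:, 1].astype(int)
    cols = grid[:, 0].astype(int)
    keep = labels[rows, cols] == 0
    return grid[keep]
```

Many of these points will land on true background rather than on a missed cell, so the masks they produce would still need filtering (for example by object size) before being merged with the refined segmentation.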
No problem about the delay; 3D is the most time-consuming, so any help will be appreciated :). Retreat means holidays, hence no work :) Enjoy, and talk next week then!
Thanks, antho
> I sent you an invite to share the files to your email. I included the original image, my custom 2D model masks and the refinements from micro_sam; as mentioned, the main improvement could be with elongated nuclei, it would be great to refine these :).

Thanks for sending the data. Unfortunately the service you used for sharing seems to require a client for downloading that is not available for Linux (and I use a Linux machine). Could you share it with a different service that enables direct download via the browser?
I sent a Google Drive link; does this work?
Yes that worked! I have downloaded the data and will take a closer look next week.
Hi @Nal44 , sorry to take a bit longer to follow up, had a busy few weeks. I am back working on the tool this week and will try to follow up here by the end of the week.
Hi, sounds good, thanks for the update :) antho
There are different ways of refining existing masks with `micro_sam`. The easiest option is to derive point prompts from the centers of the masks and then prompt the model with these points.
The function `batched_inference` can be used for this.
Here is some (non-tested!) code for this, using skimage to derive the point prompts.
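A minimal sketch of just the prompt-derivation step, wrapped in a hypothetical helper `point_prompts_from_labels` (an illustration, not part of `micro_sam`). It returns the `(N, 1, 2)` points and `(N, 1)` labels layout that `batched_inference` turned out to need in this thread, and also handles a label image without any objects:

```python
import numpy as np
from skimage.measure import regionprops


def point_prompts_from_labels(labels):
    """Hypothetical helper: one positive point prompt per object of a 2D label image."""
    props = regionprops(labels)
    # reshape(-1, 2) keeps the array 2D even when there are no objects.
    centroids = np.array([prop.centroid for prop in props], dtype=float).reshape(-1, 2)
    # regionprops returns (row, col) = (y, x); SAM expects (x, y).
    points = centroids[:, ::-1]
    point_labels = np.ones(len(points), dtype="int")  # all prompts are positive
    # Add the per-object axis: (N, 2) -> (N, 1, 2) and (N,) -> (N, 1).
    return np.expand_dims(points, 1), np.expand_dims(point_labels, 1)
```

The result could then be passed as `points=` and `point_labels=` to `batched_inference`. One caveat for elongated or curved nuclei: the centroid of a non-convex object can lie outside the object itself, in which case the prompt lands on background or on a neighbour.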
Another possible strategy is to derive bounding boxes from the segmented objects and use these as prompts instead. This can be done by passing the `boxes` argument.
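A sketch of the box variant with a hypothetical helper `box_prompts_from_labels`. It assumes that `batched_inference` expects boxes of shape `(N, 4)` in SAM's `(x_min, y_min, x_max, y_max)` order, mirroring the (x, y) convention used for points; please check the docstring of `batched_inference` before relying on this:

```python
import numpy as np
from skimage.measure import regionprops


def box_prompts_from_labels(labels):
    """Hypothetical helper: one box prompt per object, assumed (x_min, y_min, x_max, y_max) order."""
    boxes = []
    for prop in regionprops(labels):
        # For 2D images skimage gives (min_row, min_col, max_row, max_col); max values are exclusive.
        min_row, min_col, max_row, max_col = prop.bbox
        # Reorder from (row, col) to (x, y) coordinates.
        boxes.append([min_col, min_row, max_col, max_row])
    return np.array(boxes, dtype=float).reshape(-1, 4)
```

Under that assumption, the call would look like `batched_inference(predictor, image, batch_size=32, boxes=boxes, return_instance_segmentation=True)`. Boxes carry the object's extent, which may help more than a single centroid for elongated nuclei.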
Note that this code will only work for 2D. It is possible to extend it to 3D, but I would suggest starting in 2D first; once this is working well I can give hints on how to extend it to 3D.
cc @Nal44