facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

Request for Python version of connected_components and Performance Comparison Inquiry #243

Open Spark001 opened 3 months ago

Spark001 commented 3 months ago

Hi @ronghanghu

In previous issues, many people have reported installation problems with _C, such as #53, #59, and #22, and I have run into a similar issue:

UserWarning: cannot import name '_C' from 'sam2' 

Skipping the post-processing step due to the error above. 
You can still use SAM 2 and it's OK to ignore the error above, 
although some post-processing functionality may be limited 
(which doesn't affect the results in most cases; see 
https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).

Although it can run ignoring this warning, I still want to understand the impact of this post-processing on the results.

Looking at the files in csrc, I found a function called connected_components that resembles a union-find algorithm. Could you provide a pure Python version of this function? That would make installation easier, especially for people who cannot easily update their CUDA driver :).

Or is there a significant difference in time consumption between the CUDA-based implementation and the Python-based implementation? How much exactly is it? In which scenarios would there be a larger difference?

Looking forward to your reply, thank you very much.

ronghanghu commented 3 months ago

Hi @Spark001, thanks for your interest in SAM 2. Regarding your questions:

Although it can run ignoring this warning, I still want to understand the impact of this post-processing on the results.

In most scenarios, this post-processing step doesn't make a notable difference. It is intended to cover the (relatively rare) case of removing (filling) some small holes in the output masks.
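To make the idea of "filling small holes" concrete, here is a minimal, dependency-free sketch. It is not the code SAM 2 uses: the function name `fill_small_holes`, the `max_area` parameter, the 4-connectivity, and the rule that a hole must not touch the image border are all assumptions made for illustration, and the mask is assumed to be a binary list of lists.

```python
from collections import deque


def fill_small_holes(mask, max_area):
    """Return a copy of a binary mask with small enclosed background regions set to 1.

    Hypothetical helper for illustration only; `mask` is a list of lists of 0/1.
    """
    h = len(mask)
    w = len(mask[0]) if h else 0
    out = [row[:] for row in mask]
    seen = [[False] * w for _ in range(h)]

    for r0 in range(h):
        for c0 in range(w):
            if mask[r0][c0] or seen[r0][c0]:
                continue
            # Flood-fill one background region (4-connectivity, an assumption).
            region = []
            touches_border = False
            queue = deque([(r0, c0)])
            seen[r0][c0] = True
            while queue:
                r, c = queue.popleft()
                region.append((r, c))
                if r in (0, h - 1) or c in (0, w - 1):
                    touches_border = True
                for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                    if 0 <= nr < h and 0 <= nc < w and not mask[nr][nc] and not seen[nr][nc]:
                        seen[nr][nc] = True
                        queue.append((nr, nc))
            # Treat the region as a hole only if it is enclosed and small enough.
            if not touches_border and len(region) <= max_area:
                for r, c in region:
                    out[r][c] = 1
    return out
```

Calling `fill_small_holes(mask, max_area=1)` would fill single-pixel holes and leave larger ones alone, which matches the "relatively rare" cleanup role described above.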

Could you provide a pure Python version of this function?

A pure Python version would behave much like cv2.connectedComponentsWithStats in OpenCV (see https://www.geeksforgeeks.org/python-opencv-connected-component-labeling-and-analysis/) or skimage.measure.label in Scikit-Image (see https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.label). For the latter, there is a community implementation in https://github.com/facebookresearch/segment-anything-2/pull/216 that you could try :)
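For readers who want to see the union-find idea without installing anything, below is a rough sketch of connected-component labeling in plain Python. It is neither the kernel in csrc nor the code in the pull request above. The function signature, the list-of-lists mask format, and the choice to return a per-pixel area alongside the labels are assumptions made for this example, and it will be far slower than the OpenCV or Scikit-Image calls.

```python
def connected_components(mask, connectivity=4):
    """Label foreground regions of a binary mask (list of lists of 0/1).

    Illustrative sketch only. Returns (labels, counts): labels[r][c] is 0 for
    background or a positive component id, and counts[r][c] is the pixel area
    of the component containing (r, c), or 0 for background.
    """
    h = len(mask)
    w = len(mask[0]) if h else 0
    parent = list(range(h * w))  # union-find forest over flat pixel indices

    def find(i):
        # Follow parent links to the root, halving the path as we go.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[max(ra, rb)] = min(ra, rb)

    # Only neighbours already visited in raster order need to be checked.
    offsets = [(-1, 0), (0, -1)]
    if connectivity == 8:
        offsets += [(-1, -1), (-1, 1)]

    for r in range(h):
        for c in range(w):
            if not mask[r][c]:
                continue
            for dr, dc in offsets:
                nr, nc = r + dr, c + dc
                if 0 <= nr < h and 0 <= nc < w and mask[nr][nc]:
                    union(r * w + c, nr * w + nc)

    labels = [[0] * w for _ in range(h)]
    counts = [[0] * w for _ in range(h)]
    ids, sizes = {}, {}
    for r in range(h):
        for c in range(w):
            if mask[r][c]:
                label = ids.setdefault(find(r * w + c), len(ids) + 1)
                labels[r][c] = label
                sizes[label] = sizes.get(label, 0) + 1
    for r in range(h):
        for c in range(w):
            if labels[r][c]:
                counts[r][c] = sizes[labels[r][c]]
    return labels, counts
```

Labels are assigned in raster order, so the first foreground pixel encountered belongs to component 1. Passing `connectivity=8` also merges pixels that touch only diagonally.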

Or is there a significant difference in time consumption between the CUDA-based implementation and the Python-based implementation? How much exactly is it? In which scenarios would there be a larger difference?

Yes, running this post-processing step is usually much slower on CPUs than with the CUDA-based implementation, and it could be a major overhead in video applications. It's therefore recommended to either compile the CUDA kernel above for connected components or just skip the step, rather than running a CPU version of this op.

Spark001 commented 3 months ago

@ronghanghu Thanks for your quick reply 👍

I will try to find a way to compile it.