opengeos / segment-geospatial

A Python package for segmenting geospatial data with the Segment Anything Model (SAM)
https://samgeo.gishub.org
MIT License

Simplify Mask Handling When Only One Object is Detected #349

Open ro-hit81 opened 3 days ago

ro-hit81 commented 3 days ago

Description

When using sam2_models, I ran into a problem when processing images that contain only one detected object. In that case, the current code can raise the following error:

ValueError: cannot select an axis to squeeze out which has size not equal to one

This happens because the masks are squeezed without first checking their dimensions, and NumPy's squeeze raises this error whenever the requested axis does not have size 1.
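A minimal NumPy reproduction of the error, assuming (not confirmed in the issue) that several boxes yield masks shaped (N, 1, H, W) while a single box yields masks shaped (1, H, W). The small array sizes are arbitrary.

```python
import numpy as np

# Assumed shapes for illustration only.
multi = np.zeros((3, 1, 4, 5), dtype=bool)  # three boxes, singleton mask channel
single = np.zeros((1, 4, 5), dtype=bool)    # one box, no channel axis

print(multi.squeeze(1).shape)  # (3, 4, 5): axis 1 has size 1, so it is removed

try:
    single.squeeze(1)  # axis 1 has size 4, not 1, so NumPy refuses
    msg = None
except ValueError as err:
    msg = str(err)

print(msg)
```

The second call fails with the same ValueError reported above, because `squeeze(axis)` only removes axes of length one.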

I propose checking the mask dimensions before squeezing, so that images with a single detected object are handled correctly. With this check, the code runs the same way regardless of how many objects are detected.

https://github.com/opengeos/segment-geospatial/blob/8888ba6471cedc22ba790601b058cd33d5544257/samgeo/text_sam.py#L338

if len(boxes) > 0:
    masks = self.predict_sam(image_pil, boxes)
    print("Masks shape before squeeze:", masks.shape)

    # Squeeze masks based on their shape
    if masks.ndim == 4 and masks.shape[1] == 1:
        masks = masks.squeeze(1)  # Squeeze the channel dimension
    elif masks.ndim == 3 and masks.shape[0] == 1:
        masks = masks.squeeze(0)  # Squeeze the leading dimension

This change prevents the squeeze error when an image contains a single detected object. Moreover, since masks.ndim is always either 4 or 3, the code can be simplified by removing the elif branch that checks masks.ndim == 3 and masks.shape[0] == 1: a 3-D array already holds one mask per object, so squeezing its first axis would discard the object dimension instead of fixing the shape.
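A standalone sketch of the simplified logic (only the `ndim == 4` check, no `elif`). The helper name `squeeze_masks` is hypothetical and not part of samgeo; it assumes masks arrive as NumPy arrays shaped either (N, 1, H, W) or (N, H, W).

```python
import numpy as np


def squeeze_masks(masks: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return masks with one mask per box along axis 0."""
    if masks.ndim == 4 and masks.shape[1] == 1:
        # Drop the singleton channel axis: (N, 1, H, W) -> (N, H, W).
        return masks.squeeze(1)
    # Anything else (e.g. a single box giving (1, H, W)) is left unchanged.
    return masks


# Usage
print(squeeze_masks(np.zeros((3, 1, 4, 5))).shape)  # (3, 4, 5)
print(squeeze_masks(np.zeros((1, 4, 5))).shape)     # (1, 4, 5)
```

Leaving the 3-D case untouched keeps the leading object axis, so code that iterates over the masks alongside the boxes still sees exactly one mask per box.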

giswqs commented 2 days ago

Good suggestion. I encourage you to submit a pull request to fix this. Thanks.