SysCV / sam-hq

Segment Anything in High Quality [NeurIPS 2023]
https://arxiv.org/abs/2306.01567
Apache License 2.0

Alternative implementation in Refiners #127

Closed hugojarkoff closed 4 months ago

hugojarkoff commented 6 months ago

Hello everyone, and thank you for the fantastic work!

We are building Refiners, an open-source, PyTorch-based micro-framework for easily training and running adapters on top of foundation models. Just wanted to let you know that HQ-SAM is now natively supported on top of our SAM implementation!

An MWE (minimal working example) in Refiners (similar to demo_hqsam.py) would look like this:

- Download the SAM weights: `download_sam()`
- Download the HQ-SAM weights: `download_hq_sam()`
- Convert the SAM weights: `convert_sam()`
- Convert the HQ-SAM weights: `convert_hq_sam()`
- Finally, run the snippet below to do some inference using HQ-SAM:
```python
import torch
from PIL import Image

from refiners.fluxion.utils import load_from_safetensors, tensor_to_image
from refiners.foundationals.segment_anything import SegmentAnythingH
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter

# Instantiate SAM model
sam_h = SegmentAnythingH(
    device=torch.device("cuda"),
    dtype=torch.float32,
    multimask_output=False,  # Multi-mask output is not supported by HQ-SAM
)
sam_h.load_from_safetensors("tests/weights/segment-anything-h.safetensors")

# Instantiate HQ-SAM adapter, with downloaded and converted weights
hq_sam_adapter = HQSAMAdapter(
    sam_h,
    hq_mask_only=True,
    weights=load_from_safetensors("tests/weights/refiners-sam-hq-vit-h.safetensors"),
)

# Patch SAM with HQ-SAM by “injecting” the adapter
hq_sam_adapter.inject()

# Define the image to segment and the prompt
tennis_image = Image.open("tests/foundationals/segment_anything/test_sam_ref/tennis.png")
box_points = [[(4, 13), (1007, 1023)]]

# Run inference
high_res_masks, *_ = sam_h.predict(input=tennis_image, box_points=box_points)

predicted_mask = tensor_to_image(high_res_masks)
predicted_mask.save("predicted_mask.png")
```

You should now have generated the following mask (note: the image has been downsized by 50% in postprocessing to fit on GitHub):

[image: predicted_mask]
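Not part of the snippet above, but as a quick visual sanity check one could overlay the saved mask on the input image with plain PIL (a minimal sketch reusing the file names from the code above, and assuming the saved mask is a grayscale/binary image):

```python
from PIL import Image

# Hypothetical visualization step (not part of the Refiners API): overlay the
# predicted mask on the input image to sanity-check the segmentation.
image = Image.open("tests/foundationals/segment_anything/test_sam_ref/tennis.png").convert("RGB")
mask = Image.open("predicted_mask.png").convert("L").resize(image.size)

# Highlight the masked region with a 50% red blend, keep the rest of the image as-is.
highlight = Image.blend(image, Image.new("RGB", image.size, (255, 0, 0)), alpha=0.5)
overlay = Image.composite(highlight, image, mask)
overlay.save("predicted_mask_overlay.png")
```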

A few more things:

Feedback welcome!