heyoeyo / muggled_sam

Muggled SAM: Segmentation without the magic
Apache License 2.0

How did you make it real-time? #1

Closed: charchit7 closed this 3 months ago

charchit7 commented 3 months ago

Hey, nice work! I wanted to know how you made it run in real time. Did you try converting the model to ONNX format?

heyoeyo commented 3 months ago

Thanks!

The mask generation in the model is already real time; there aren't any significant optimizations in this repo (yet). This is part of the original SAM model design, and it works because image encoding and mask generation are handled by separate models. The image encoding is slow, but it only happens once on start-up, whereas the mask generation (and prompt encoding) is fast and can run in real time on every change to the input prompts (e.g. the mask can update to follow the mouse as it hovers).
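To make that split concrete, here is a minimal sketch of the "encode once, decode many times" pattern. The two functions (`slow_encode_image`, `fast_predict_mask`) are hypothetical stand-ins, not part of muggled_sam or SAM: they only mimic the cost structure, with a slow step that runs once per image and a cheap step that re-runs on every prompt change.

```python
import time
import numpy as np

# Hypothetical stand-ins (NOT the muggled_sam API) for the two parts of SAM:
# a heavy image encoder that runs once per image, and a light prompt -> mask
# step that re-runs every time the prompt (e.g. the hovered point) changes.

def slow_encode_image(image_rgb: np.ndarray) -> np.ndarray:
    """Placeholder for the expensive image encoder: 1024x1024x3 -> 64x64 'embedding'."""
    time.sleep(0.2)  # simulate a costly forward pass
    h, w, c = image_rgb.shape
    # Average 16x16 patches into a coarse 64x64 grid (stands in for real features)
    return image_rgb.reshape(64, h // 64, 64, w // 64, c).mean(axis=(1, 3, 4))

def fast_predict_mask(embedding: np.ndarray, point_xy, tol: float = 5.0) -> np.ndarray:
    """Placeholder for the cheap prompt encoder + mask decoder: point -> boolean mask."""
    x, y = point_xy
    return np.abs(embedding - embedding[y, x]) < tol

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(1024, 1024, 3)).astype(np.float32)

embedding = slow_encode_image(image)        # paid once, on start-up
hover_points = [(5, 5), (20, 40), (63, 0)]  # e.g. successive mouse positions (grid coords)
masks = []
for pt in hover_points:
    t0 = time.perf_counter()
    masks.append(fast_predict_mask(embedding, pt))  # only this step re-runs per prompt
    elapsed_ms = 1000 * (time.perf_counter() - t0)
    print(f"point {pt}: {masks[-1].sum()} cells selected in {elapsed_ms:.3f} ms")
```

The point of the design is that the per-prompt loop never touches the image encoder, so its cost is set entirely by the small decoder step.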

I haven't tried exporting to ONNX yet, since I still need to clean up some of the model implementation details. However, I do plan to make an export script in the near future. For now, if you want ONNX versions, you can create them with code like:

```python
import torch
from lib.make_sam import make_sam_from_state_dict

# Export the image-encoder component of SAM
_, sammodel = make_sam_from_state_dict("/path/to/model.pt")
torch.onnx.export(
    model=sammodel.image_encoder,
    args=torch.randn((1, 3, 1024, 1024), dtype=torch.float32),
    f="sam_image_encoder.onnx",
    input_names=["input"],
    output_names=["output"],
    export_params=True,
    do_constant_folding=True,
    opset_version=14,
)
```

This only creates an ONNX version of the image encoder (not the whole model), and the export currently prints a bunch of warnings, but the result is usable. You can test the ONNX model with code like:

```python
import cv2
import numpy as np
import onnxruntime

# Load image, resize to 1024x1024 (this stretches non-square images) and convert BGR -> RGB
img_bgr = cv2.imread("/path/to/image.jpg")
img_onnx = cv2.resize(img_bgr, dsize=(1024, 1024))
img_onnx = cv2.cvtColor(img_onnx, cv2.COLOR_BGR2RGB)

# Apply the model's RGB normalization, then reorder to shape: 1x3x1024x1024
mean_rgb = np.float32([123.675, 116.28, 103.53])
std_rgb = np.float32([58.395, 57.12, 57.375])
img_onnx = (np.float32(img_onnx) - mean_rgb) / std_rgb
img_onnx = np.transpose(img_onnx, (2, 0, 1))[None, :]  # HxWxC -> 1xCxHxW (rows stay rows)
onnx_input = {"input": np.ascontiguousarray(img_onnx)}

# Run the onnx model (setting providers explicitly avoids an error on onnxruntime
# builds that ship with more than one execution provider, e.g. GPU builds)
ort_session = onnxruntime.InferenceSession("sam_image_encoder.onnx", providers=["CPUExecutionProvider"])
onnx_output = ort_session.run(None, onnx_input)
print("Num outputs:", len(onnx_output))
print("Example output shape:", onnx_output[0].shape, "DType:", onnx_output[0].dtype)
```
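If the pre-processing is needed in more than one place, it can be pulled into a small function. The sketch below uses a hypothetical name (`preprocess_rgb`) and the same mean/std values as the snippet above; the channel reorder uses `np.transpose(..., (2, 0, 1))` so image rows stay rows. The optional zero-padding mirrors the pad-after-normalize step in the original SAM v1 pre-processing; whether muggled_sam expects padding or a plain stretch is an assumption here, so it is off by default.

```python
from typing import Optional
import numpy as np

# Per-channel RGB statistics from the example above (in 0-255 pixel units)
MEAN_RGB = np.float32([123.675, 116.28, 103.53])
STD_RGB = np.float32([58.395, 57.12, 57.375])

def preprocess_rgb(img_rgb: np.ndarray, pad_to: Optional[int] = None) -> np.ndarray:
    """Hypothetical helper: HxWx3 RGB image -> normalized 1x3xHxW float32 array.

    If pad_to is given, the normalized image is zero-padded on the bottom/right
    up to pad_to x pad_to (assumes the image already fits inside that size).
    """
    norm = (img_rgb.astype(np.float32) - MEAN_RGB) / STD_RGB  # broadcast over channels
    chw = np.transpose(norm, (2, 0, 1))  # HxWxC -> CxHxW
    if pad_to is not None:
        _, h, w = chw.shape
        if h > pad_to or w > pad_to:
            raise ValueError(f"Image {h}x{w} does not fit in {pad_to}x{pad_to}")
        # Zero after normalization corresponds to the mean color in pixel space
        chw = np.pad(chw, ((0, 0), (0, pad_to - h), (0, pad_to - w)))
    # Add the batch dimension; a contiguous float32 array is a safe input for onnxruntime
    return np.ascontiguousarray(chw[None], dtype=np.float32)

# Usage on a small dummy image (real use: a 1024-sized RGB image, e.g. loaded with cv2)
img = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)
x = preprocess_rgb(img, pad_to=8)
print(x.shape)  # (1, 3, 8, 8)
```

The result can be passed straight into the input dictionary, e.g. `{"input": preprocess_rgb(img_rgb)}`.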

At some point I'll make a more convenient version that builds the image pre-processing directly into the model, and I'll add exports for the other model components.

charchit7 commented 3 months ago

Thank you so much for the detailed answer. @heyoeyo

charchit7 commented 3 months ago

Hey @heyoeyo, btw, the above code works for both SAM1 and SAM2, right?

heyoeyo commented 3 months ago

Yes, the image encoders of both SAM v1 & v2 take the same inputs, and the make_sam_from_state_dict function will figure out which version (and size) to load based on the model weights it's given, so the code should work for either version and any model size.

However, the outputs of the v1 & v2 models are different. V1 outputs a single tensor, while v2 outputs a list of 3 differently sized tensors (plus 3 positional embeddings, though I'll be removing those in a future update). So the resulting ONNX models are not interchangeable.
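Because the two exports return different numbers of outputs, downstream code has to branch on which one it was given. A minimal sketch, assuming the output counts described above (1 tensor for v1; 3 feature tensors for v2, or 6 while the positional embeddings are still exported). The helper name `guess_sam_version` is hypothetical, and it is only a heuristic: it looks at the count of outputs, not their shapes.

```python
from typing import Sequence
import numpy as np

def guess_sam_version(encoder_outputs: Sequence[np.ndarray]) -> str:
    """Hypothetical heuristic: guess the SAM version from image-encoder outputs.

    encoder_outputs is e.g. the list returned by ort_session.run(None, onnx_input).
    v1 -> a single tensor; v2 -> 3 multi-scale tensors, plus (for now)
    3 positional embeddings. Shapes are not checked.
    """
    num_outputs = len(encoder_outputs)
    if num_outputs == 1:
        return "v1"
    if num_outputs in (3, 6):  # features only, or features + positional embeddings
        return "v2"
    raise ValueError(f"Unexpected number of encoder outputs: {num_outputs}")

# Usage with dummy arrays (shapes are arbitrary placeholders, not real SAM shapes)
fake_v1 = [np.zeros((1, 8, 4, 4), dtype=np.float32)]
fake_v2 = [np.zeros((1, 8, 4, 4), dtype=np.float32) for _ in range(6)]
print(guess_sam_version(fake_v1), guess_sam_version(fake_v2))  # v1 v2
```

Once the positional embeddings are dropped from the v2 export, the `6` case could be removed.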