ChaoningZhang / MobileSAM

This is the official code for MobileSAM project that makes SAM lightweight for mobile applications and beyond!
Apache License 2.0

Image embedding in ONNX model #81

Closed chaitanyakrishna1248 closed 1 year ago

chaitanyakrishna1248 commented 1 year ago

Why can't we include the image embedding part in the ONNX model? Is there a specific reason to compute the embeddings beforehand, before sending them to the ONNX model? Is it even possible to include it in the ONNX model?

GuichardVictor commented 1 year ago

I have attempted to export the encoder to ONNX, but it did not work very well.

I have modified the SamPredictor class so that its forward pass runs only the image encoder:

import torch
from torch import nn

...

class SamPredictor(nn.Module):
    # Note: __init__ must call super().__init__() now that this subclasses nn.Module.
    ...
    @torch.no_grad()
    def forward(
        self,
        transformed_image: torch.Tensor,
    ):
        # Normalize and pad the already-resized image, then run only the image encoder.
        input_image = self.model.preprocess(transformed_image)
        features = self.model.image_encoder(input_image)

        return features
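
For context, the transformed_image argument is expected to already be resized with ResizeLongestSide, the same transform SamPredictor.set_image uses. Here is a rough sketch of that preprocessing on a real image; it assumes mobile_sam mirrors the original segment-anything utils, the file path is a placeholder, and predictor/device are built as in the export script below:

import cv2
import torch
from mobile_sam.utils.transforms import ResizeLongestSide  # assumed to match segment-anything

transform = ResizeLongestSide(1024)  # MobileSAM's image_encoder.img_size

bgr = cv2.imread("example.jpg")                       # placeholder path
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
resized = transform.apply_image(rgb)                  # HWC uint8, longest side -> 1024
x = torch.as_tensor(resized, device=device)
x = x.permute(2, 0, 1).contiguous()[None, :, :, :]    # 1x3xHxW

features = predictor(x)  # preprocess() inside forward() normalizes and pads to 1024x1024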

Here is the code to export the encoder:

import numpy as np
import torch
import onnxruntime
from mobile_sam import sam_model_registry, SamPredictor

# Only needed if you also quantize the exported model afterwards.
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

model_type = "vit_t"
sam_checkpoint = "./weights/mobile_sam.pt"

device = "cuda" if torch.cuda.is_available() else "cpu"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam.to(device=device)
mobile_sam.eval()

predictor = SamPredictor(mobile_sam)

# Dummy input on the same device as the model (tracing fails if they differ).
image = torch.randn(1, 3, 224, 224, device=device)

torch.onnx.export(
    predictor, 
    (image,),
    "model.onnx",
    input_names=["image"],
    # opset_version=16,
    output_names=["features"],
    do_constant_folding=False,
    dynamic_axes={"image": {2: "height", 3: "width"}}
)

# Reference output from eager PyTorch, moved to CPU for the numpy comparison.
forward_features = predictor(image).cpu().numpy()

session = onnxruntime.InferenceSession("model.onnx")
onnx_features = session.run(None, {
    "image": image.cpu().numpy()
})[0]

onnx_ok = np.allclose(forward_features, onnx_features)
print(onnx_ok) # False

print(
    torch.nn.functional.l1_loss(
        torch.as_tensor(forward_features),
        torch.as_tensor(onnx_features),
    ).item()
) # ~0.04

As you can see, while I managed to export the encoder, it does not produce the same output as the PyTorch model.

Maybe someone could help with this.
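
One debugging step worth trying (just a sketch, not something verified in this thread) is to re-export with a static input shape and a pinned opset, then compare again with an explicit tolerance, to see whether the dynamic axes contribute to the mismatch. The code reuses the variables from the script above:

# Re-export without dynamic axes and with a pinned opset as a sanity check.
torch.onnx.export(
    predictor,
    (image,),
    "model_static.onnx",
    input_names=["image"],
    output_names=["features"],
    opset_version=16,
    do_constant_folding=True,
)

static_session = onnxruntime.InferenceSession("model_static.onnx")
static_features = static_session.run(None, {"image": image.cpu().numpy()})[0]

# Compare against the eager PyTorch output with an explicit tolerance.
print(np.allclose(forward_features, static_features, rtol=1e-3, atol=1e-5))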

I have successfully exported the encoder and the decoder, you can take a look at this.
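
For completeness, here is a rough sketch of how the features from this encoder export can be fed into a prompt decoder ONNX (e.g. one produced the usual way with the SAM export_onnx_model script). The input/output names follow the original segment-anything ONNX example and the decoder path is a placeholder, so treat this as an assumption rather than a recipe from this thread:

import numpy as np
import onnxruntime

encoder = onnxruntime.InferenceSession("model.onnx")
decoder = onnxruntime.InferenceSession("sam_decoder.onnx")   # placeholder path

image_embeddings = encoder.run(None, {"image": image.cpu().numpy()})[0]

# One foreground point plus the padding point, in the transformed image's coordinate frame.
point_coords = np.array([[[100.0, 150.0], [0.0, 0.0]]], dtype=np.float32)
point_labels = np.array([[1.0, -1.0]], dtype=np.float32)

masks, iou_predictions, low_res_logits = decoder.run(None, {
    "image_embeddings": image_embeddings,
    "point_coords": point_coords,
    "point_labels": point_labels,
    "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
    "has_mask_input": np.zeros(1, dtype=np.float32),
    "orig_im_size": np.array([224, 224], dtype=np.float32),  # matches the dummy input above
})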

chaitanyakrishna1248 commented 1 year ago

Alright thanks! I will look at that.

ryouchinsa commented 1 year ago

You can find the converter for the encoding part to the ONNX format in the repository Segment Anything CPP Wrapper.

Here is the Segment Anything CPP Wrapper for macOS. This code originated from Segment Anything CPP Wrapper and is implemented in the macOS app RectLabel. We customized the original code so that getMask() uses the previous mask result, called low_res_logits, and retains the previous mask array for undo/redo actions.

We hope this macOS version will help with developing the iOS version. Please let us know your opinion.
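
For anyone replicating this in Python rather than C++: reusing the names from the encoder/decoder sketch above, feeding the previous call's low_res_logits back in as mask_input looks roughly like this (input names are assumptions based on the standard SAM decoder export, not taken from the RectLabel code, and new_point_coords/new_point_labels are illustrative):

# Refine the previous mask instead of starting from scratch: pass the prior
# low_res_logits back as mask_input and set has_mask_input to 1.
masks, iou_predictions, low_res_logits = decoder.run(None, {
    "image_embeddings": image_embeddings,
    "point_coords": new_point_coords,               # updated user prompt
    "point_labels": new_point_labels,
    "mask_input": low_res_logits,                   # previous low-res mask logits (1x1x256x256)
    "has_mask_input": np.ones(1, dtype=np.float32),
    "orig_im_size": np.array([224, 224], dtype=np.float32),
})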

[Attached image: sam_polygon]

GuichardVictor commented 1 year ago

I can confirm that my script does work, as I have used it for a personal project. @chaitanyakrishna1248 you can close this issue :)

chaitanyakrishna1248 commented 1 year ago

Thank you, I will try this!