tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Full compatibility with Segment Anything #1544

Open astnmsn opened 3 months ago

astnmsn commented 3 months ago

I have not found any equivalent request

Feature description

Add support for the operators needed to run the vit_b SAM model, which can be found here:

Pth Direct download

ONNX file sam_vit_b.onnx.zip

I inspected the model using Netron and compared its nodes against the supported-operators list here.

Below is the list of operators in the network that are partially or fully missing from the support list:

Operator (missing support)

Feature motivation

I am currently trying to load and run the segment-anything ONNX model inside an image editing app, to provide a masking experience similar to the demo available here. I am integrating it into an existing Rust codebase that uses wgpu, compiles to WASM, and runs in the browser.

Before I can select Burn as the ML library for this workflow, I need to be sure that it supports the operators used by the model.

antimora commented 3 months ago

Thanks for filing this. This is helpful as we prioritize ONNX ops. If you have a direct link to the ONNX file, can you also link this?

antimora commented 3 months ago

Updating Expand to supported for Import, since we just added this op. I need to update the docs.

antimora commented 3 months ago

Submitted a PR to fix the supported OPs document: https://github.com/tracel-ai/burn/pull/1547

astnmsn commented 3 months ago

This is the model download link provided by the SAM repo. I have also added it to the original post.

antimora commented 3 months ago

I think it might be faster to implement the model manually in Burn and load the pth weights file, which we now support.

You can check out an existing model to see how it's done: https://github.com/tracel-ai/models/tree/main/resnet-burn

We also have a YOLOX object detection PR in the works: https://github.com/tracel-ai/models/pull/24

@laggui has written a great tutorial on this subject: https://dev.to/laggui/transitioning-from-pytorch-to-burn-45m

Recently, we made tons of enhancements to the PyTorchFileRecorder: https://discord.com/channels/1038839012602941528/1144670451763785769/1216788417984335872


@laggui, @nathanielsimard, @ashdtu, would this be worth implementing ourselves? Should we move this ticket to the models repo?

laggui commented 3 months ago

The community is always one step ahead 😄

We've actually discussed adding SAM to our models and this was in the plans following the release.

We still haven't decided whether we want to reimplement it and use the PyTorch file recorder to import the weights or use the ONNX import.

antimora commented 3 months ago

@laggui, if we decide to work on this, I am more inclined to add the ONNX ops. That would give us the biggest bang for the buck, rather than spending time reimplementing the model by hand (although I am not sure how complex it is).

laggui commented 2 months ago

Btw, I'm not sure if anyone has delved into the SAM code for ONNX export, but it doesn't include all the operations needed to actually run the model on an input image. The encoder is left out of the ONNX export entirely, and the exported ONNX model expects image embeddings as input.

In their example, they still use their PyTorch implementation to compute the embeddings that are then passed to ONNX Runtime.

So even if we support the missing operations in this issue, SAM support will still not be complete. Is this what you expected @astnmsn?

astnmsn commented 2 months ago

@laggui Thanks for asking, and yes, that is expected. We plan to run the first half of the model, which generates the embeddings, on the backend using PyTorch. Only the second half, which produces the masks from the embeddings and the cursor/click positions, will run on the client.
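For the client-side half, the SAM repository's ONNX example notebook prepares the decoder's click inputs by rescaling them to the encoder's 1024-pixel long side and appending a padding point labelled -1 when no box prompt is given. Below is a minimal std-only Rust sketch of that step. The 1024 value, the 1/0/-1 label convention, and the `point_coords`/`point_labels` input names are taken from that Python example and should be checked against the exported model. `preprocess_shape` and `prepare_clicks` are hypothetical helpers: they only build flat buffers and do not call any Burn or ONNX Runtime API.

```rust
/// Assumed long side of the SAM image encoder input (1024 in the SAM repo).
const LONG_SIDE: f32 = 1024.0;

/// Size the image is resized to so its longest side equals `long_side`,
/// mirroring the rounding used by SAM's ResizeLongestSide transform.
fn preprocess_shape(old_h: usize, old_w: usize, long_side: f32) -> (usize, usize) {
    let scale = long_side / old_h.max(old_w) as f32;
    let new_h = (old_h as f32 * scale + 0.5) as usize;
    let new_w = (old_w as f32 * scale + 0.5) as usize;
    (new_h, new_w)
}

/// Hypothetical helper: turn clicks `(x, y, is_foreground)` given in original
/// image pixels into flat `point_coords` (N x 2) and `point_labels` (N) buffers.
/// A padding point (0, 0) with label -1 is appended, as SAM's ONNX example
/// does when no box prompt is supplied.
fn prepare_clicks(clicks: &[(f32, f32, bool)], orig_h: usize, orig_w: usize) -> (Vec<f32>, Vec<f32>) {
    let (new_h, new_w) = preprocess_shape(orig_h, orig_w, LONG_SIDE);
    let sx = new_w as f32 / orig_w as f32;
    let sy = new_h as f32 / orig_h as f32;

    let mut coords = Vec::with_capacity((clicks.len() + 1) * 2);
    let mut labels = Vec::with_capacity(clicks.len() + 1);
    for &(x, y, foreground) in clicks {
        // Rescale each click into the resized-image coordinate frame.
        coords.push(x * sx);
        coords.push(y * sy);
        // 1 = foreground click, 0 = background click.
        labels.push(if foreground { 1.0 } else { 0.0 });
    }
    // Padding point: scaling (0, 0) is a no-op, so it is pushed directly.
    coords.extend_from_slice(&[0.0, 0.0]);
    labels.push(-1.0);
    (coords, labels)
}
```

In the app described above, these buffers would be reshaped to (1, N, 2) and (1, N) and fed to the decoder together with the embeddings computed on the backend.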

antimora commented 2 months ago

Regarding the Tile op: we need to rename our current repeat op to repeat_dim and implement a proper repeat that handles all dimensions at once.
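ONNX's Tile op takes one repeat count per dimension and tiles the input along all of them at once, which is why a single-dimension repeat is not enough on its own. Below is a minimal std-only Rust sketch of those semantics on a flat row-major buffer. `tile` and `repeat_dim` are hypothetical helpers written for illustration. They are not Burn's tensor API, and Burn's own repeat op may impose different restrictions.

```rust
/// Sketch of ONNX `Tile` semantics on a row-major buffer: output dim d has
/// size shape[d] * repeats[d], and each output element copies the input
/// element at (i_0 % shape[0], ..., i_n % shape[n]).
fn tile<T: Copy>(data: &[T], shape: &[usize], repeats: &[usize]) -> (Vec<T>, Vec<usize>) {
    assert_eq!(shape.len(), repeats.len(), "one repeat count per dimension");
    assert_eq!(data.len(), shape.iter().product::<usize>(), "data/shape mismatch");

    let rank = shape.len();
    let out_shape: Vec<usize> = shape.iter().zip(repeats).map(|(s, r)| s * r).collect();
    let out_len: usize = out_shape.iter().product();

    // Row-major strides of the input tensor.
    let mut strides = vec![1usize; rank];
    for d in (0..rank.saturating_sub(1)).rev() {
        strides[d] = strides[d + 1] * shape[d + 1];
    }

    let mut out = Vec::with_capacity(out_len);
    let mut idx = vec![0usize; rank]; // current multi-index into the output
    for _ in 0..out_len {
        // Wrap each output coordinate back into the input with a modulo.
        let src: usize = (0..rank).map(|d| (idx[d] % shape[d]) * strides[d]).sum();
        out.push(data[src]);
        // Advance the multi-index, last dimension fastest.
        for d in (0..rank).rev() {
            idx[d] += 1;
            if idx[d] < out_shape[d] {
                break;
            }
            idx[d] = 0;
        }
    }
    (out, out_shape)
}

/// Hypothetical single-dimension repeat, shown as the special case of `tile`
/// where every repeat count except one is 1 (illustration only, not Burn's API).
fn repeat_dim<T: Copy>(data: &[T], shape: &[usize], dim: usize, times: usize) -> (Vec<T>, Vec<usize>) {
    let mut repeats = vec![1usize; shape.len()];
    repeats[dim] = times;
    tile(data, shape, &repeats)
}
```

Because tiling is separable, calling the single-dimension helper once per axis gives the same result as one `tile` call. A native all-dimensions op avoids the intermediate copies.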

antimora commented 2 months ago

Resolving this ticket will resolve https://github.com/tracel-ai/burn/issues/1560 as well.

laggui commented 2 months ago

Current state of required ops based on the latest PRs:

op_type Burn Import
Add
Cast
Concat
Constant
ConstantOfShape
Conv
ConvTranspose
Cos
Div
Equal
Erf
Expand
Floor
Gather
Gemm
LayerNormalization ✔️
MatMul ✔️
Mul
Not ✔️
OneHot
Pow
Reciprocal
ReduceMax ✔️
ReduceMean ✔️
Relu
Reshape
Resize
Shape ✔️
Sin ✔️
Slice
Softmax
Sqrt
Sub
Tile
Transpose
Unsqueeze
Where ✔️