deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Request for SAM2 Model in ONNX Format and Synset File for Classification Output #3483

Open leleZeng opened 1 month ago

leleZeng commented 1 month ago

Description

Hello DJL Team,

I am currently using the SAM2 model from the DJL ModelZoo for inference. However, I have run into two limitations and would like to request your support in resolving them:

ONNX Model Format: Currently, the SAM2 model is only available in PyTorch (.pt) format. For broader compatibility and easier integration with other frameworks, I would like to request that an ONNX (.onnx) version of the SAM2 model be added to the ModelZoo.

Classification Output and Synset File: The SAM2 model's inference results do not appear to include classification outputs. To support classification tasks, could you please provide a synset.txt file or another label-mapping file for this model, so that users can easily interpret the classification results?

These features would greatly enhance the usability and flexibility of the SAM2 model in a variety of projects.

Thank you for your work and consideration!

frankfliu commented 1 month ago

@leleZeng I created a PR to add ONNX model support: https://github.com/deepjavalibrary/djl/pull/3492. Please try 0.31.0-SNAPSHOT to get the ONNX SAM2 model from our model zoo.

The SAM2 model is not a classification model; it won't generate classes (the model doesn't know what the object is). It can only segment the object from the background.

leleZeng commented 1 month ago

Thank you for adding support for SAM2's ONNX model. The demo results look very good.

I have an additional question: how did you split the model into two parts, one for encoding (sam2-hiera-large) and one for decoding, by tracing the model? Could you explain the steps or method you used to achieve this separation?

frankfliu commented 1 month ago

You can take a look at this Python code: https://github.com/deepjavalibrary/djl/blob/master/examples/docs/trace_sam2_img.py#L37

You can create a wrapper model that contains two methods: encode() and decode() (mapped to forward()).
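The wrapper idea can be illustrated with PyTorch's `torch.jit.trace_module`, which traces several methods of one module in a single call and produces one TorchScript module exposing all of them. This is a minimal sketch under stated assumptions, not the code in trace_sam2_img.py: `ToyEncoder`, `ToyDecoder`, and `EncodeDecodeWrapper` are hypothetical stand-ins for SAM2's image encoder and prompt-conditioned mask decoder, and the tensor shapes are arbitrary.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the heavy image encoder (run once per image)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=4, stride=4)

    def forward(self, image):
        return self.conv(image)


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for the prompt-conditioned mask decoder."""

    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(8, 1, kernel_size=1)

    def forward(self, embedding, point):
        # Toy "prompting": fold the point coordinates into the embedding.
        return self.head(embedding + point.sum())


class EncodeDecodeWrapper(nn.Module):
    """Wrapper exposing encode() plus forward(), where forward() is the decode step."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, image):
        return self.encoder(image)

    def forward(self, embedding, point):
        return self.decoder(embedding, point)


wrapper = EncodeDecodeWrapper(ToyEncoder(), ToyDecoder()).eval()
image = torch.rand(1, 3, 64, 64)
point = torch.tensor([[10.0, 20.0]])

with torch.no_grad():
    # Example inputs for each method: the decoder is traced on a real embedding.
    embedding = wrapper.encode(image)
    traced = torch.jit.trace_module(
        wrapper,
        {"encode": image, "forward": (embedding, point)},
    )

# traced.save("wrapper.pt") would write one TorchScript file holding both methods.
```

Splitting the work this way lets the expensive encoding run once per image while the cheap decode step is repeated for each new prompt.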

leleZeng commented 1 month ago

@frankfliu The segmentation results are very good, but I'm a bit unsure how to implement object tracking across video frames. I would like to know whether SAM2 can track objects, particularly across consecutive video frames, and how to keep the same object consistently identified over time. Specifically, I have the following questions:

1. Does SAM2 provide any built-in object tracking functionality, or does it need to be combined with other algorithms (such as Kalman filters or optical flow) to achieve tracking?
2. If additional algorithms are needed, do you have any recommended approaches or example code for integrating them?
3. Are there any performance optimization tips for handling real-time video streams?

Thank you for your help!
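One generic way to keep identities consistent when a segmenter runs independently on each frame is tracking-by-detection: turn each mask into a bounding box and match boxes between consecutive frames by intersection-over-union (IoU). The following is a minimal, stdlib-only sketch of that swapped-in technique (essentially the association step of SORT-style trackers, without the Kalman filter). It is not SAM2 or DJL functionality; `mask_to_box`, `iou`, and `IouTracker` are hypothetical names, and masks are assumed to arrive as 2-D lists of 0/1.

```python
from typing import Dict, List, Optional, Tuple

Box = Tuple[int, int, int, int]  # (x_min, y_min, x_max, y_max), inclusive pixels


def mask_to_box(mask: List[List[int]]) -> Optional[Box]:
    """Tight bounding box around the non-zero pixels of a binary mask."""
    ys = [y for y, row in enumerate(mask) if any(row)]
    xs = [x for row in mask for x, v in enumerate(row) if v]
    if not ys:
        return None
    return (min(xs), min(ys), max(xs), max(ys))


def area(r: Box) -> int:
    return (r[2] - r[0] + 1) * (r[3] - r[1] + 1)


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two inclusive pixel boxes."""
    ix = min(a[2], b[2]) - max(a[0], b[0]) + 1
    iy = min(a[3], b[3]) - max(a[1], b[1]) + 1
    if ix <= 0 or iy <= 0:
        return 0.0
    inter = ix * iy
    return inter / (area(a) + area(b) - inter)


class IouTracker:
    """Greedy IoU matching between the previous frame's tracks and new boxes."""

    def __init__(self, threshold: float = 0.3):
        self.threshold = threshold
        self.tracks: Dict[int, Box] = {}
        self.next_id = 0

    def update(self, boxes: List[Box]) -> List[int]:
        """Return one track ID per input box, reusing IDs for overlapping boxes."""
        ids = [-1] * len(boxes)
        # Every (track, box) pair, highest IoU first.
        pairs = sorted(
            ((iou(t_box, box), tid, i)
             for tid, t_box in self.tracks.items()
             for i, box in enumerate(boxes)),
            reverse=True,
        )
        used_tracks = set()
        for score, tid, i in pairs:
            if score < self.threshold:
                break
            if tid in used_tracks or ids[i] != -1:
                continue
            ids[i] = tid
            used_tracks.add(tid)
        # Unmatched boxes start new tracks.
        for i in range(len(boxes)):
            if ids[i] == -1:
                ids[i] = self.next_id
                self.next_id += 1
        # Tracks not matched this frame are dropped (no re-identification).
        self.tracks = {tid: boxes[i] for i, tid in enumerate(ids)}
        return ids
```

For example, calling `update([(0, 0, 9, 9), (20, 20, 29, 29)])` on a fresh tracker returns `[0, 1]`, and a following frame with `[(21, 21, 30, 30), (1, 1, 10, 10)]` returns `[1, 0]`, because each shifted box overlaps its predecessor. A Kalman filter, as the question suggests, would typically predict each track's box for the next frame before matching, which helps when objects move quickly or are briefly missed; this sketch simply drops unmatched tracks.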

leleZeng commented 1 month ago

This demo is an object tracking implementation using SAM2: link.