lxning opened this issue 1 year ago
Also, many models for video tasks need to maintain a per-sequence state memory.
@bhack Is the backend cache (see stateful cache example) able to cover your use case "a per sequence state memory"?
Do you have any doc/markdown related to the stateful cache? I was talking about fairly classical models with internal short- and long-term memories in different flavors, such as: https://github.com/hkchengrex/XMem https://github.com/hkchengrex/Cutie https://github.com/yoxu515/aot-benchmark etc.
@bhack The current proposal is mostly focused on caching the previous frame/segment state/prediction, which is more of a serving scenario. Classic short/long-term memory, I believe, is mostly defined as part of the model. I would love to learn more about your thoughts/suggestions to see if we can accommodate that as well.
Yes, these memories are defined in the model. I don't know what kind of serving optimization is possible, but I suppose that in the worst case you need to maintain a session id between clients and the inference server, so you know which model instance to route each client request to; otherwise you are going to lose the model state entirely. Many of these models expose a memory reset call, so an inference hook to reset the state and kick off a new session, without paying the model unload/reload overhead, would certainly be useful. This needs to be part of the communication protocol between the client and the server.
@bhack the "sequence_id" is introduced in the "PredictionsRequest" protocol. It is used for TorchServe to know which backend worker should be chosen to serve this request.
What about a hook to control the state reset? Generally these models expose a function to reset the state.
@bhack Maybe we can add a control function in handler_utils if this function can be generalized. Then it can be used in a custom handler in the backend layer. We can work together to add a video stateful inference example if you are interested.
Yes it could be useful
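To make the idea concrete, here is a minimal sketch of what such a generalized control function could look like in a custom handler. Everything below is hypothetical: the ts_reset_session header name, the should_reset_session helper, and the model's clear_memory() call are assumptions for illustration, not an existing TorchServe API.

```python
# Hypothetical sketch only: TorchServe does not ship this helper today.
# It illustrates how a generalized "reset state" control function could be
# wired into a custom handler, as discussed above.

from ts.torch_handler.base_handler import BaseHandler


def should_reset_session(ctx, idx):
    # Assumption: the client signals a reset via a custom request header.
    # The header name "ts_reset_session" is made up for this sketch.
    return ctx.get_request_header(idx, "ts_reset_session") == "true"


class VideoStatefulHandler(BaseHandler):
    def inference(self, data, *args, **kwargs):
        results = []
        for idx, frame in enumerate(data):
            if should_reset_session(self.context, idx):
                # Models such as XMem/Cutie expose a method to clear their
                # internal short/long-term memory; the exact name varies.
                self.model.clear_memory()  # assumption: model-specific reset call
            results.append(self.model(frame))
        return results
```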
/cc @hkchengrex @yoxu515 in case they want to share their feedback with us.
I think @nikitakaraevv could also be interested in sharing feedback on this for the new CoTracker version that can run inference on video chunks: https://github.com/facebookresearch/co-tracker/issues/37#issuecomment-1769491730
Any update on this?
Meta's new SAM 2 model has very similar needs: https://github.com/facebookresearch/segment-anything-2
@bhack https://sam2.metademolab.com/demo is backed by TorchServe. :-)
Interesting, how is the inference_state managed?
I suppose the problem is quite similar to the other models we have discussed in this thread.
That's the caching part of SAM2. Check https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/video_predictor_example.ipynb
I've seen this, but I don't understand how it is going to work in a multi-user setup with TorchServe. Is it going to exchange the session/state back and forth with every client? How are they going to call reset_state concurrently?
Both gRPC and HTTP are able to support sticky sessions. This design guarantees that each client session has a dedicated backend worker. The state of the session is handled by the model itself. You can take a look at this example: https://github.com/pytorch/serve/tree/master/examples/stateful/sequence_continuous_batching.
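For reference, the core pattern in that example boils down to something like the following simplified sketch (not the actual example code): keep per-sequence state on the worker, keyed by the sequence id exposed through context.py. The plain dict stands in for whatever cache backend the handler chooses, and the state-passing model call is an assumption.

```python
# Simplified sketch of a handler that keeps per-sequence state on the worker.
# The real sequence_continuous_batching example is more elaborate; this only
# shows the core idea: key all state by the sequence id of each request.

from ts.torch_handler.base_handler import BaseHandler


class SequenceStatefulHandler(BaseHandler):
    def initialize(self, context):
        super().initialize(context)
        # One entry per live sequence; TorchServe's sticky routing guarantees
        # that all requests of a sequence land on this same worker.
        self.sequence_state = {}

    def inference(self, data, *args, **kwargs):
        results = []
        for idx, item in enumerate(data):
            seq_id = self.context.get_sequence_id(idx)  # provided by context.py
            state = self.sequence_state.get(seq_id)
            output, state = self.model(item, state)  # assumption: model takes/returns state
            self.sequence_state[seq_id] = state
            results.append(output)
        return results
```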
Interesting, but it could be harder to do with a standard exported model that is internally stateful. So it seems we need to have something similar by design.
In the meantime, do you think we could have a small example for this SAM2-like solution? I think it would be better than nothing.
@mreso could you please check if we can provide a SAM2 example for CX?
It seems that users are going to have issues exporting the video mode to ONNX: https://github.com/ibaiGorordo/ONNX-SAM2-Segment-Anything/issues/6
So any example of how to serve it with TorchServe would be very useful.
🚀 The feature
Author: Li Ning
Background
A stateful model possesses the ability to detect interdependencies between successive inference requests. This type of model maintains a persistent state across inference requests, thereby establishing a linkage between the outcomes of prior inquiries and those that follow. Notable illustrations of stateful models encompass online speech recognition systems, such as the Long Short-Term Memory (LSTM) model. Employing stateful inference mandates that the model server adheres to the sequential order of inference requests, ensuring predictions build upon the previous outcomes.
TorchServe is a stateless model server. It treats each inference request as independent and does not maintain state across inference requests. A stateful model requires TorchServe to be extended into a stateful model server, which generally must be able to route all requests belonging to the same sequence to the same backend worker and preserve the inference state across those requests.
Within this context, TorchServe offers a mechanism known as sequence batching. This approach involves the retrieval of an individual inference request from a particular sequence, followed by the combination of multiple requests originating from different sequences into a unified batch. Each request is associated with a unique sequence ID, which serves as a key employed by custom handlers to store and retrieve values within the backend cache store, fostering efficient management of stateful inference processes. Clients can also reuse the sequence ID when a connection resumes, as long as the sequence has not expired on the TorchServe side.
The following picture shows the workflow of stateful inference. A job group has a job queue which stores incoming inference requests from a stream. The max capacity of a job queue is defined by maxSequenceJobQueueSize. A sequence batch aggregator polls one inference request from each job group, and the resulting batch of requests is sent to the backend.
Requirements Scope
To support a stateful model, the requirements for TorchServe are scoped as follows.
Design
The above picture shows the architecture changes to support stateful inference.
API Layer
Streaming applications usually use HTTP, gRPC, or Kafka to transfer messages. This design only discusses gRPC streaming, since SageMaker does not support HTTP request streaming at this moment, and Kafka or similar messaging systems require support on the application side.
A new endpoint StreamPredictions2 is introduced for sequence batching. The sequence_id defined in PredictionsRequest and PredictionResponse is similar to a topic in a Kafka messaging system. TorchServe routes the inference requests to a specific worker based on the sequence_id.
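As a rough client-side sketch, driving one sequence through this endpoint could look like the following. It assumes Python stubs generated from TorchServe's inference.proto as inference_pb2 / inference_pb2_grpc; the exact streaming semantics and field layout should be checked against the shipped proto.

```python
# Sketch of a gRPC client driving a stateful sequence via StreamPredictions2.
# Assumes stubs generated from TorchServe's inference.proto; field names follow
# the description above (sequence_id on PredictionsRequest/PredictionResponse).

import grpc
import inference_pb2
import inference_pb2_grpc


def request_stream(frames, sequence_id):
    # All requests of one video/session carry the same sequence_id, so
    # TorchServe routes them to the same backend worker.
    for frame_bytes in frames:
        yield inference_pb2.PredictionsRequest(
            model_name="stateful_video_model",  # assumption: example model name
            input={"data": frame_bytes},
            sequence_id=sequence_id,
        )


def run(frames, sequence_id="session-001"):
    # 7070 is TorchServe's default gRPC inference port.
    with grpc.insecure_channel("localhost:7070") as channel:
        stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)
        for response in stub.StreamPredictions2(request_stream(frames, sequence_id)):
            print(response.prediction)
```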
Core Layer
In the existing TorchServe there is only one jobQueue to store incoming inference requests. Each worker of a model has a batcher that polls a batch of jobs from the job queue.
Stateful inference requires a sequence of inference requests to be routed to the same worker. A single jobQueue is not able to separate the jobs of different sequences, so this design introduces a new concept, the "job group". A job group has a single job queue storing the jobs of one sequence. A batcher of a worker continuously polls a set of job groups and, in parallel, polls one job from each group. This ensures that each request within a batch comes from a distinct sequence. A job group (i.e. a sequence) is removed if no new job arrives for it within max_idle_milliseconds.
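The following is only a conceptual Python illustration of that polling behaviour (the real implementation lives in TorchServe's Java frontend); the parameter names mirror maxSequenceJobQueueSize and max_idle_milliseconds from the text above.

```python
# Conceptual illustration of sequence batching in the frontend (the real code
# is Java). Each job group holds the queue of one sequence; the batcher takes
# at most one job per group per batch, so a batch never contains two jobs of
# the same sequence.

import time


class JobGroup:
    """Holds the job queue of a single sequence."""

    def __init__(self, sequence_id, max_queue_size):
        self.sequence_id = sequence_id
        self.max_queue_size = max_queue_size  # maxSequenceJobQueueSize
        self.jobs = []
        self.last_polled = time.time()

    def offer(self, job):
        # Reject the job when the per-sequence queue is full.
        if len(self.jobs) >= self.max_queue_size:
            return False
        self.jobs.append(job)
        return True

    def poll(self):
        if self.jobs:
            self.last_polled = time.time()
            return self.jobs.pop(0)
        return None


def poll_batch(job_groups, max_idle_milliseconds):
    """Take at most one job per group, so each batch mixes distinct sequences only."""
    batch, now = [], time.time()
    for group in list(job_groups.values()):
        job = group.poll()
        if job is not None:
            batch.append(job)
        elif (now - group.last_polled) * 1000 > max_idle_milliseconds:
            # Expire the sequence: no new job arrived within max_idle_milliseconds.
            del job_groups[group.sequence_id]
    return batch
```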
Backend Layer
The user chooses a caching solution in the custom handler to store and fetch the inference state based on the inference sequenceId. A separate TorchServe Cache RFC covers this part.
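For illustration only, the contract a stateful handler needs from whichever cache it picks is small: store, fetch, and evict state keyed by the sequence id. The names below are made up for this sketch and may differ from what the Cache RFC defines; the process-local dict could be swapped for Redis or any other backend.

```python
# Minimal sketch of the cache contract a stateful handler relies on: store and
# fetch opaque state keyed by the sequence id. The concrete backend is the
# user's choice; these names are illustrative, not a TorchServe API.

from typing import Any, Dict, Optional


class SequenceStateCache:
    def __init__(self):
        self._store: Dict[str, Any] = {}

    def fetch(self, sequence_id: str) -> Optional[Any]:
        return self._store.get(sequence_id)

    def store(self, sequence_id: str, state: Any) -> None:
        self._store[sequence_id] = state

    def evict(self, sequence_id: str) -> None:
        # Called when a sequence expires or the client resets the session.
        self._store.pop(sequence_id, None)
```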
Motivation, pitch
A stateful model possesses the ability to detect interdependencies between successive inference requests. This type of model maintains a persistent state across inference requests, thereby establishing a linkage between the outcomes of prior inquiries and those that follow. Notable illustrations of stateful models encompass online speech recognition systems, such as the Long Short-Term Memory (LSTM) model. Employing stateful inference mandates that the model server adheres to the sequential order of inference requests, ensuring predictions build upon the previous outcomes.
Within this context, TorchServe offers a mechanism known as sequence batching. This approach involves the retrieval of an individual inference request from a particular sequence, followed by the combination of multiple requests originating from different sequences into a unified batch. Each request is associated with a unique sequence ID, which can be extracted using the "get_sequence_id" function of context.py. This sequence ID serves as a key employed by custom handlers to store and retrieve values within the backend cache store, fostering efficient management of stateful inference processes. Clients can also reuse the sequence ID when a connection resumes, as long as the sequence has not expired on the TorchServe side.
Alternatives
No response
Additional context
No response