pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

[RFC] Sequence Batching for Stateful Inference #2743

Open lxning opened 10 months ago

lxning commented 10 months ago

🚀 The feature

Author: Li Ning

Background

A stateful model captures the interdependencies between successive inference requests: it maintains a persistent state across requests, so the outcome of each request is linked to the outcomes that came before it. Typical examples are online speech recognition systems, such as models based on Long Short-Term Memory (LSTM). Serving such a model requires the model server to preserve the order of inference requests within a sequence so that each prediction can build on the previous outcomes.

TorchServe is a stateless model server: it treats each inference request as independent and does not maintain any state across inference requests. Supporting a stateful model therefore requires extending TorchServe into a stateful model server, one that can keep the requests of a sequence together and route them to the same worker so that state can be carried from one request to the next.

Within this context, TorchServe offers a mechanism known as sequence batching. With this approach, one inference request is retrieved from each sequence, and requests originating from different sequences are combined into a single batch. Each request is associated with a unique sequence ID, which custom handlers use as a key to store and retrieve values in the backend cache store, enabling efficient management of stateful inference. The client can also reuse the sequence ID when a connection resumes, as long as the sequence has not expired on the TorchServe side.
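
For illustration, with a batch size of 3 and three active sequences, a single backend batch contains at most one pending request from each sequence. The sequence IDs and payloads below are hypothetical:

# Hypothetical batch for batch size 3: one pending request per active sequence.
batch = [
    {"sequence_id": "seq-a", "input": b"chunk-7"},
    {"sequence_id": "seq-b", "input": b"chunk-2"},
    {"sequence_id": "seq-c", "input": b"chunk-11"},
]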

The following picture shows the workflow of stateful inference. A job group has a job queue which stores incoming inference requests from a stream. The maximum capacity of a job queue is defined by maxSequenceJobQueueSize. A sequence batch aggregator polls an inference request from each job group, and the resulting batch of requests is sent to the backend.

[Figure: workflow of stateful inference with job groups and the sequence batch aggregator]

Requirements Scope

To support a stateful model, the requirements for TorchServe are scoped as follows.

  1. gRPC stream
    • A sequence of inference requests is sent to TorchServe as a continuous gRPC stream.
    • The client sends each individual inference request as one gRPC request on the stream.
    • A sequence cannot be idle for more than X milliseconds.
    • The responses for a sequence of inference requests are sent back to the client as a continuous gRPC stream; the server sends each inference request's response as one gRPC response on the stream.
  2. All inference requests in a sequence are associated with the same sequence id string.
  3. A stateful model configures max_idle_milliseconds. TorchServe monitors whether a sequence of inference requests has exceeded the idle timeout.
  4. max_number_sequence is the maximum number of sequences that can be accepted; it must be equal to or larger than batch size * number of workers, since each inference request in a batch comes from a different sequence (see the sketch after this list).
  5. The user maintains the inference state in a customized handler.
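
The relationship in requirement 4 can be made concrete with a small sanity check; the variable names below follow the RFC wording and are illustrative, not necessarily the final configuration keys:

# Illustrative only: names follow the RFC wording, not final configuration keys.
batch_size = 4
num_workers = 2
max_idle_milliseconds = 60_000   # a sequence is dropped after idling this long
max_number_sequence = 8          # must be >= batch_size * num_workers

assert max_number_sequence >= batch_size * num_workers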

Design

[Figure: architecture changes to support stateful inference]

The above picture shows the architecture changes to support stateful inference.

API Layer

Streaming applications usually use HTTP, gRPC, or Kafka to transfer messages. This design only discusses gRPC streams, since SageMaker does not support HTTP request streaming at this moment, and Kafka (or a similar messaging system) requires application-level support.

A new endpoint, StreamPredictions2, is introduced for sequence batching. The sequence_id defined in PredictionsRequest and PredictionResponse is similar to a topic in a Kafka messaging system. TorchServe routes the inference requests to a specific worker based on the sequence_id.

message PredictionsRequest {
    // Name of model.
    string model_name = 1; //required

    // Version of model to run prediction on.
    string model_version = 2; //optional

    // Input data for model prediction
    map<string, bytes> input = 3; //required

    // SequenceId is required for StreamPredictions2 API.
    optional string sequence_id = 4; //optional
}

message PredictionResponse {
    // Response content for prediction
    bytes prediction = 1;

    // SequenceId is required for StreamPredictions2 API.
    optional string sequence_id = 2; //optional

    // Error information for StreamPredictions2 API.
    optional google.rpc.Status status = 3; //optional
}

service InferenceAPIsService {
    // Check health status of the TorchServe server.
    rpc Ping(google.protobuf.Empty) returns (TorchServeHealthResponse) {}

    // Predictions entry point to get inference using default model version.
    rpc Predictions(PredictionsRequest) returns (PredictionResponse) {}

    // Streaming response for an inference request.
    rpc StreamPredictions(PredictionsRequest) returns (stream PredictionResponse) {}

    // Bi-direction streaming inference and response.
    rpc StreamPredictions2(stream PredictionsRequest) returns (stream PredictionResponse) {}
}
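
For illustration, a client could drive the bidirectional StreamPredictions2 RPC roughly as in the Python sketch below. It assumes stubs generated from the proto above (the module names inference_pb2 / inference_pb2_grpc, the model name, and the local gRPC address are assumptions, not part of this proposal):

# Sketch of a StreamPredictions2 client; assumes Python stubs generated from the
# proto above (inference_pb2 / inference_pb2_grpc) and a local gRPC endpoint.
import grpc

import inference_pb2
import inference_pb2_grpc


def request_stream(sequence_id, chunks):
    # Every request in the sequence carries the same sequence_id so that
    # TorchServe can route the whole sequence to the same backend worker.
    for chunk in chunks:
        yield inference_pb2.PredictionsRequest(
            model_name="stateful_model",
            input={"data": chunk},
            sequence_id=sequence_id,
        )


def main():
    channel = grpc.insecure_channel("localhost:7070")
    stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)
    chunks = [b"chunk-1", b"chunk-2", b"chunk-3"]
    # Responses stream back on the same call, one per request in the sequence.
    for response in stub.StreamPredictions2(request_stream("seq-a", chunks)):
        print(response.sequence_id, response.prediction)


if __name__ == "__main__":
    main()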

Core Layer

In the existing TorchServe there is only one jobQueue, which stores incoming inference requests. Each worker of a model has a batcher that polls a batch of jobs from the job queue.

Stateful inference requires a sequence of inference requests to be routed to the same worker, and a single jobQueue cannot separate jobs from different sequences. This design therefore introduces a new concept, the "job group". A job group has a single job queue storing the jobs from the same sequence. A worker's batcher continuously polls a set of job groups and concurrently polls one job from each group, ensuring that each request within a batch comes from a distinct sequence. A job group (i.e., a sequence) is removed if no new job arrives in the group within max_idle_milliseconds.

[Figure: adding jobs to job groups and polling a batch across job groups]
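
The frontend is implemented in Java; the Python sketch below only illustrates the polling behavior described above, and the class and function names are hypothetical:

# Illustrative sketch of sequence batching over job groups. Names are
# hypothetical; the real logic lives in the TorchServe Java frontend.
import time
from collections import deque


class JobGroup:
    def __init__(self, sequence_id, max_queue_size):
        # Bounded per-sequence queue (maxSequenceJobQueueSize in the design).
        self.sequence_id = sequence_id
        self.queue = deque(maxlen=max_queue_size)
        self.last_active = time.monotonic()

    def add_job(self, job):
        self.queue.append(job)
        self.last_active = time.monotonic()

    def poll(self):
        return self.queue.popleft() if self.queue else None

    def is_expired(self, max_idle_ms):
        return (time.monotonic() - self.last_active) * 1000 > max_idle_ms


def poll_batch(job_groups, batch_size, max_idle_ms):
    """Build a batch with at most one job per sequence; drop idle sequences."""
    batch = []
    for group in list(job_groups.values()):
        if group.is_expired(max_idle_ms):
            del job_groups[group.sequence_id]  # sequence hit the idle timeout
            continue
        job = group.poll()
        if job is not None:
            batch.append(job)
        if len(batch) == batch_size:
            break
    return batch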

Backend Layer

The user chooses a caching solution in the customized handler to store and fetch the inference state based on the inference sequenceId. A separate TorchServe Cache RFC covers this part.
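
As a minimal illustration of that contract (not the cache RFC itself), a custom handler could key an in-process dict on the sequence ID exposed by the request context; the handler class and its toy "model" below are hypothetical:

# Hypothetical handler sketch: per-sequence state kept in an in-process dict
# keyed by the sequence ID from the request context. A real deployment would
# use the caching solution from the separate TorchServe Cache RFC.
from ts.torch_handler.base_handler import BaseHandler


class StatefulEchoHandler(BaseHandler):
    def initialize(self, context):
        self.context = context
        self.initialized = True
        self.cache = {}  # sequence_id -> accumulated per-sequence state

    def preprocess(self, data):
        inputs = []
        for idx, row in enumerate(data):
            # get_sequence_id returns the sequence ID attached to the idx-th
            # request of the batch (see context.py).
            sequence_id = self.context.get_sequence_id(idx)
            state = self.cache.setdefault(sequence_id, {"count": 0})
            state["count"] += 1
            inputs.append((sequence_id, row.get("data") or row.get("body")))
        return inputs

    def inference(self, inputs):
        # Toy "model": report how many requests each sequence has sent so far.
        return [f"{seq_id}: request #{self.cache[seq_id]['count']}"
                for seq_id, _ in inputs]

    def postprocess(self, outputs):
        return outputs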

Motivation, pitch

As described in the Background section, a stateful model maintains state across successive inference requests, and TorchServe's sequence batching supports this by retrieving one request from each sequence and combining requests from different sequences into a single batch. Each request carries a unique sequence ID, which can be extracted with the get_sequence_id function of context.py and used by custom handlers as the key for storing and retrieving state in the backend cache store. The client can also reuse the sequence ID when a connection resumes, as long as the sequence has not expired on the TorchServe side.

Alternatives

No response

Additional context

No response

bhack commented 10 months ago

Many models for video tasks also need to maintain a per-sequence state memory.

lxning commented 10 months ago

@bhack Is the backend cache (see stateful cache example) able to cover your use case "a per sequence state memory"?

bhack commented 10 months ago

Do you have any doc/markdown related to the stateful cache? I was talking about quite classical models with internal short- and long-term memories in different flavors, like: https://github.com/hkchengrex/XMem https://github.com/hkchengrex/Cutie https://github.com/yoxu515/aot-benchmark etc.

HamidShojanazeri commented 10 months ago

@bhack the current proposal is mostly focused on caching the previous frame/segment state/prediction; this is more of a serving scenario. Classic short/long-term memory, I believe, is mostly defined as part of the model. I would love to learn more about your thoughts/suggestions to see if we can accommodate that as well.

bhack commented 10 months ago

Yes, these memories are defined in the model. I don't know what kind of serving optimization is possible, but I suppose that in the worst case you need to maintain a session id between clients and the inference server so that you know which model instance to route each client's requests to, or you are going to totally lose the model state. Some of these models generally have a memory reset call, so an inference hook to reset the state and kick off a new session would certainly be useful, without wasting model unload and reload overhead. This needs to be part of the communication protocol between the client and the server.

lxning commented 10 months ago

@bhack the "sequence_id" is introduced in the "PredictionsRequest" protocol. TorchServe uses it to know which backend worker should be chosen to serve this request.

bhack commented 10 months ago

What about a hook to control the state reset? Generally these models expose a function to reset the state.

lxning commented 10 months ago

@bhack Maybe we can add a control function in handler_utils if this function can be generalized. Then it can be used in a custom handler in the backend layer. We can work together to add a video stateful inference example if you are interested.

bhack commented 10 months ago

Yes it could be useful

bhack commented 10 months ago

@bhack Maybe we can add a control function in handler_utils if this function can be generalized. Then it can be used in a custom handler in the backend layer. We can work together to add a video stateful inference example if you are interested.

/cc @hkchengrex @yoxu515 in the case they want to share with us their feedback.

I think that also @nikitakaraevv could be interested in sharing feedback on this for the new CoTracker version that could run inference on video chunks: https://github.com/facebookresearch/co-tracker/issues/37#issuecomment-1769491730

bhack commented 7 months ago

Any update on this?

bhack commented 1 month ago

New Meta model sam2 has very similar needs: https://github.com/facebookresearch/segment-anything-2

lxning commented 1 month ago

@bhack https://sam2.metademolab.com/demo is backed by TorchServe. :-)

bhack commented 1 month ago

Interesting, how is the inference_state managed? I suppose the problem is quite similar to the other models we have discussed in this thread.

lxning commented 1 month ago

That's the caching part of SAM2. Check https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/video_predictor_example.ipynb

bhack commented 1 month ago

I've seen this, but I don't understand how it is going to work in a multi-user setup with TorchServe. Is it going to exchange the session/state back and forth with every client? How are they going to call reset_state concurrently?

lxning commented 1 month ago

Both gRPC and HTTP are able to support sticky sessions. This design guarantees that each client session has a dedicated backend worker. The state of the session is handled by the model itself. You can take a look at this example: https://github.com/pytorch/serve/tree/master/examples/stateful/sequence_continuous_batching.

bhack commented 1 month ago

Interesting, but it could be harder to do with a standard exported model that is internally stateful. So it seems we need to have something similar by design.

In the meantime, do you think we could have a small example for this SAM2-like solution? I think it would be better than nothing.

lxning commented 1 month ago

@mreso could you please check if we can provide a SAM2 example for CX?

bhack commented 3 weeks ago

@mreso could you please check if we can provide a SAM2 example for CX?

It seems that users are going to have issues exporting the video mode to ONNX: https://github.com/ibaiGorordo/ONNX-SAM2-Segment-Anything/issues/6

So any example of how to serve it in TorchServe would be very useful.