pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

grpc batch serving, inference.proto #2106

Open ndvbd opened 1 year ago

ndvbd commented 1 year ago

When running torch-model-archiver, is there a way to specify the inference.proto file? Looking at https://github.com/pytorch/serve/blob/master/frontend/server/src/main/resources/proto/inference.proto, I see: map<string, bytes> input = 3; //required. So if I want to send multiple sentences (a batch) over gRPC, how do I do that if I can't change the proto definition to something like: repeated string input = 3; //required ?

On the one hand, inference.proto defines the input as a map<string, bytes>, but on the other hand, the preprocess() code in https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/Transformer_handler_generalized.py accepts a list of map<string, bytes> (I believe there is a typo in the function docs?).

How can both be true?

lxning commented 1 year ago

@ndvbd TorchServe supports dynamic batching, which means the TorchServe frontend automatically aggregates inference requests from clients and then sends a batch of requests to the backend. That is why the handler's preprocess() receives a list.
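
For illustration, here is a minimal sketch of how a custom handler's preprocess() typically iterates over the batch that the frontend assembles. The "data"/"body" request keys follow TorchServe's handler conventions; the handler class name is hypothetical.

```python
from ts.torch_handler.base_handler import BaseHandler


class MyTextHandler(BaseHandler):
    """Hypothetical handler: each element of `requests` is one client request
    that the TorchServe frontend grouped into a batch."""

    def preprocess(self, requests):
        texts = []
        for request in requests:
            # TorchServe places each request's payload under "data" or "body".
            payload = request.get("data") or request.get("body")
            if isinstance(payload, (bytes, bytearray)):
                payload = payload.decode("utf-8")
            texts.append(payload)
        # `texts` now holds one entry per client request in the batch.
        return texts
```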

It is fine if you want to batch manually on the client side and send that batch to TorchServe. However, TorchServe still treats it as a single request and forwards it to the backend as such. In that case you have to customize the backend handler to decode the request yourself; in other words, you need to split the input_text yourself.
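
As an illustration of that manual approach, a custom preprocess() could unpack a payload in which the client serialized several sentences itself. The JSON packing format and the handler name here are assumptions for the sketch, not a TorchServe convention:

```python
import json

from ts.torch_handler.base_handler import BaseHandler


class ManualBatchHandler(BaseHandler):
    """Hypothetical handler for a client that packs many sentences
    into a single request (e.g. a JSON array in the "data" field)."""

    def preprocess(self, requests):
        sentences = []
        for request in requests:
            payload = request.get("data") or request.get("body")
            if isinstance(payload, (bytes, bytearray)):
                payload = payload.decode("utf-8")
            # The client chose JSON as the packing format; TorchServe does not
            # interpret it, so the handler must split it back out here.
            sentences.extend(json.loads(payload))
        return sentences
```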

ndvbd commented 1 year ago

Thank you for the answer. I would like to suggest an improvement to TorchServe. Assume the server can handle a maximum batch size of 80. I suggest supporting "batches of batches": instead of a client with 50 predictions having to send 50 gRPC requests (paying the latency 50 times), the client could send all 50 as one "batch" to TorchServe, and TorchServe would take those 50, take another 30 from another client, and run the 80 together. For that, I suggest changing inference.proto from map<string, bytes> to repeated map<string, bytes>, or something along those lines. Of course we can pack/serialize/marshal the 50 requests ourselves, but then the 30 from the other client cannot be added, and we lose potential here, because TorchServe is not aware of how many tensors are inside the serialized data.
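
For context, this is roughly what the "pack it ourselves" workaround looks like today with the existing map<string, bytes> input field. The sketch assumes Python stubs generated from inference.proto (stub and message names follow TorchServe's example gRPC client; the JSON packing, model name, and port are illustrative). Because the 50 sentences arrive as one opaque bytes value, the frontend cannot count them or merge them with other clients' requests:

```python
import json

import grpc

# Stubs generated from inference.proto (e.g. with grpc_tools.protoc).
import inference_pb2
import inference_pb2_grpc


def send_packed_batch(sentences, model_name="my_model", target="localhost:7070"):
    """Pack many sentences into one gRPC request (illustrative workaround).

    TorchServe sees a single request, so it cannot combine these sentences
    with requests from other clients when forming a backend batch.
    """
    channel = grpc.insecure_channel(target)
    stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)
    # The whole client-side batch is serialized into one opaque bytes value.
    packed = json.dumps(sentences).encode("utf-8")
    request = inference_pb2.PredictionsRequest(
        model_name=model_name, input={"data": packed}
    )
    return stub.Predictions(request)
```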