SeldonIO / MLServer

An inference server for your machine learning models, including support for multiple frameworks, multi-model serving and more
https://mlserver.readthedocs.io/en/latest/

Inference streaming support #1750

Closed RobertSamoilescu closed 3 months ago

RobertSamoilescu commented 3 months ago

This PR adds streaming support to MLServer by allowing the user to implement a predict_stream method in the runtime, which takes as input an async generator of requests and outputs an async generator of responses.

from typing import AsyncIterator

from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse


class MyModel(MLModel):

    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
        # Unary path: a single request in, a single response out.
        pass

    async def predict_stream(
        self, payloads: AsyncIterator[InferenceRequest]
    ) -> AsyncIterator[InferenceResponse]:
        # Streaming path: a stream of requests in, a stream of responses out.
        pass

While the input-output types for predict remain the same, the predict_stream implementation handles a stream of inputs and a stream of outputs. This design choice is quite general and covers all input-output scenarios: a single request with a single response, a single request with a stream of responses, a stream of requests with a single response, and a stream of requests with a stream of responses.

Although streamed inputs are not really meaningful for REST and are currently not supported there, they are quite natural for gRPC. Users who would like to use streamed inputs will therefore have to use gRPC.
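As a purely illustrative sketch (not part of this PR), a runtime covering the common single-request, streamed-response case could look roughly like this. The TextStreamModel name and the word-by-word "generation" are hypothetical; the codec calls use the StringCodec helpers shown in the client examples below.

from typing import AsyncIterator

from mlserver import MLModel
from mlserver.codecs import StringCodec
from mlserver.types import InferenceRequest, InferenceResponse


class TextStreamModel(MLModel):

    async def predict_stream(
        self, payloads: AsyncIterator[InferenceRequest]
    ) -> AsyncIterator[InferenceResponse]:
        async for request in payloads:
            # Decode the "prompt" input (assumed to be a single-string tensor).
            [prompt] = StringCodec.decode_input(request.inputs[0])

            # Hypothetical "generation": emit one response per word.
            for token in prompt.split():
                yield InferenceResponse(
                    model_name=self.name,
                    outputs=[StringCodec.encode_output("output", [token])],
                )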

Exposed endpoints

We expose the following endpoints (plus the equivalents that include the model version) to the user:

/v2/models/{model_name}/infer
/v2/models/{model_name}/infer_stream
/v2/models/{model_name}/generate
/v2/models/{model_name}/generate_stream

The first two are general-purpose endpoints, while the latter two are LLM-specific (see the Open Inference Protocol). Note that the infer and generate endpoints point to the predict implementation, while infer_stream and generate_stream point to the predict_stream implementation defined above.

Client calls

REST non-streaming

import os
import requests
from mlserver import types
from mlserver.codecs import StringCodec

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

api_url = "http://localhost:8080/v2/models/text-model/generate"
response = requests.post(api_url, json=inference_request.dict())
response = types.InferenceResponse.parse_raw(response.text)
print(StringCodec.decode_output(response.outputs[0]))
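The examples here load generate-request.json from the repository's test data; its exact contents are not reproduced in this description. As a rough, assumed equivalent, a request with a single "prompt" string input (the same input name the gRPC examples re-encode below) can be built programmatically:

from mlserver import types
from mlserver.codecs import StringCodec

# Hypothetical stand-in for generate-request.json: one string input
# named "prompt", encoded as plain strings so it serialises to JSON.
inference_request = types.InferenceRequest(
    inputs=[
        StringCodec.encode_input(
            "prompt", ["What is the capital of France?"], use_bytes=False
        )
    ]
)
print(inference_request.dict())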

REST streaming

import os
import httpx
from httpx_sse import connect_sse
from mlserver import types
from mlserver.codecs import StringCodec

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

with httpx.Client() as client:
    with connect_sse(client, "POST", "http://localhost:8080/v2/models/text-model/generate_stream", json=inference_request.dict()) as event_source:
        for sse in event_source.iter_sse():
            response = types.InferenceResponse.parse_raw(sse.data)
            print(StringCodec.decode_output(response.outputs[0]))

gRPC non-streaming

import os
import grpc
import mlserver.grpc.converters as converters
import mlserver.grpc.dataplane_pb2_grpc as dataplane
import mlserver.types as types
from mlserver.codecs import StringCodec
from mlserver.grpc.converters import ModelInferResponseConverter

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# need to convert from string to bytes for grpc
inference_request.inputs[0] = StringCodec.encode_input("prompt", inference_request.inputs[0].data.__root__)
inference_request_g = converters.ModelInferRequestConverter.from_types(
    inference_request, model_name="text-model", model_version=None
)
grpc_channel = grpc.insecure_channel("localhost:8081")
grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)
response = grpc_stub.ModelInfer(inference_request_g)

response = ModelInferResponseConverter.to_types(response)
print(StringCodec.decode_output(response.outputs[0]))

gRPC streaming

import asyncio
import os
import grpc
import mlserver.grpc.converters as converters
import mlserver.grpc.dataplane_pb2_grpc as dataplane
import mlserver.types as types
from mlserver.codecs import StringCodec
from mlserver.grpc.converters import ModelInferResponseConverter

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# need to convert from string to bytes for grpc
inference_request.inputs[0] = StringCodec.encode_input("prompt", inference_request.inputs[0].data.__root__)
inference_request_g = converters.ModelInferRequestConverter.from_types(
    inference_request, model_name="text-model", model_version=None
)

async def get_inference_request_stream(inference_request):
    yield inference_request


async def main():
    async with grpc.aio.insecure_channel("localhost:8081") as grpc_channel:
        grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)
        inference_request_stream = get_inference_request_stream(inference_request_g)

        # ModelStreamInfer is a bidirectional stream: it takes an async
        # iterator of requests and yields responses as they are produced.
        async for response in grpc_stub.ModelStreamInfer(inference_request_stream):
            response = ModelInferResponseConverter.to_types(response)
            print(StringCodec.decode_output(response.outputs[0]))


asyncio.run(main())
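Because ModelStreamInfer is a bidirectional stream, the streamed-input case mentioned above only requires the request generator to yield more than one message. A hypothetical variant of the generator, reusing the same converted request:

# Hypothetical: each yielded message is a separate streamed input.
async def get_inference_request_stream(inference_request):
    for _ in range(3):
        yield inference_request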

Limitations