triton-inference-server / pytriton

PyTriton is a Flask/FastAPI-like interface that simplifies Triton's deployment in Python environments.
https://triton-inference-server.github.io/pytriton/
Apache License 2.0

Binary output truncated... #34

Closed rilango closed 11 months ago

rilango commented 1 year ago

Description

When a large array is converted to npz format and returned, the client receives a truncated result.
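
The length mismatch can also be seen without Triton at all. The following is a minimal, Triton-free sketch (assuming only numpy) that serializes a dummy array to npz and wraps the payload the same way the server code below does:

# standalone sketch: serialize to npz, wrap as a bytes-dtype numpy array,
# and compare the payload length before and after the wrap
import io
import numpy as np

buf = io.BytesIO()
np.savez(buf, embeddings=np.random.rand(1, 1, 512))
payload = buf.getvalue()

wrapped = np.array([[payload]], dtype=bytes)   # same wrapping as in the server below
received = wrapped[0][0]                       # what the client effectively gets back

print(len(payload), len(received))             # lengths differ: the payload is truncated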

To reproduce

A minimal example that reproduces the error by running the code:

# server
import io
import h5py
import numpy as np

from typing import Dict

from pytriton.decorators import batch
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton

class TritonInferModel:

    def __init__(self) -> None:
        self._model_name = 'TestImpl'

        # Encoder
        self._enc_ip = (Tensor(name="input_string", shape=(-1,), dtype=bytes),
                        Tensor(name="format", shape=(1, ), dtype=bytes),)

        self._enc_op = (Tensor(name="embedding", shape=(1, ), dtype=bytes),)

    def _to_npz(self, **kwargs):
        with io.BytesIO() as output:
            np.savez(output, **kwargs)
            return output.getvalue()

    def _to_h5(self, **kwargs):
        with io.BytesIO() as output:
            with h5py.File(output, 'w') as h5f:
                for label, data in kwargs.items():
                    h5f.create_dataset(label, data=data)
            return output.getvalue()

    @batch
    def encoder(self, **inputs: np.ndarray) -> Dict[str, np.ndarray]:
        embeddings = {'embeddings': np.random.rand(1, 1, 512)}  # dummy payload to serialize

        # decode the requested output format ('npz' or 'h5') from the bytes input
        format = np.char.decode(inputs.pop("format").astype("bytes"), encoding="utf-8")
        format = [np.char.decode(p.astype("bytes"), "utf-8").item() for p in format][0]
        if format == 'npz':
            output = self._to_npz(**embeddings)
        elif format == 'h5':
            output = self._to_h5(**embeddings)
        else:
            raise Exception(f'Unknown format {format}')

        print("=================> size", len(output), np.array([[output]], dtype=bytes).shape)

        return {"embedding": np.array([[output]], dtype=bytes), }

    def start(self, triton):
        print(f"Loading {self._model_name} encoder...")

        triton.bind(
            model_name="embeddings",
            infer_func=self.encoder,
            inputs=self._enc_ip,
            outputs=self._enc_op,
            config=ModelConfig(max_batch_size=8),
            strict=True,
        )

def main():
    with Triton() as triton:
        inferer = TritonInferModel()
        inferer.start(triton)
        triton.serve()

if __name__ == "__main__":
    main()
# client
import io
import h5py
import numpy as np
import tritonclient.http as httpclient

from tritonclient.utils import *

encoder_model = "embeddings"

test_input = ['Input_1', 'Input_2', 'Input_3']

def test_encoder_n_decoder():
    with httpclient.InferenceServerClient("localhost:8000") as client:

        format = 'npz'
        smis_input = np.array([test_input]).astype(bytes)
        format_input = np.array([[format]]).astype(bytes)
        inputs = [
            httpclient.InferInput("input_string", smis_input.shape,
                                np_to_triton_dtype(smis_input.dtype)),
            httpclient.InferInput("format", format_input.shape,
                                np_to_triton_dtype(format_input.dtype)),
        ]

        inputs[0].set_data_from_numpy(smis_input)
        inputs[1].set_data_from_numpy(format_input)

        outputs = [
            httpclient.InferRequestedOutput("embedding"),
        ]

        response = client.infer(encoder_model,
                                inputs,
                                request_id=str(1),
                                outputs=outputs)

        result = response.get_response()
        embeddings = response.as_numpy("embedding")

        print("=" * 80)
        print(result)
        print("SMILES              : ", smis_input)
        print("Raw Content length  : ", len(embeddings[0][0]))

        embeddings = io.BytesIO(embeddings[0][0])
        if format == 'h5':
            conv_embeddings = h5py.File(embeddings, 'r')
            conv_embeddings = conv_embeddings['embeddings']
        elif format == 'npz':
            conv_embeddings = np.load(embeddings)['embeddings']

        print("Embedding           : ", conv_embeddings.shape, type(conv_embeddings))
        print("=" * 80)

if __name__ == '__main__':
    test_encoder_n_decoder()

Observed results and expected behavior

Expected behavior: the client receives the complete serialized npz/h5 payload, matching the size printed on the server.

Actual behavior: the payload received by the client is truncated (its length is shorter than the size printed on the server).

Environment

rilango commented 1 year ago

This issue is reproducible in v0.3.0 too.

pziecina-nv commented 1 year ago

Hi @rilango, thank you for your detailed reproduction scripts.

In your code, the truncation occurs when the npz payload is wrapped into a numpy array:

np.array([[output]], dtype=bytes)

With the bytes dtype, numpy strips trailing \x00 bytes. For arbitrary binary payloads, the object dtype is required instead.

https://triton-inference-server.github.io/pytriton/0.3.0/binding_models/#defining-inputs-and-outputs
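
The stripping is easy to see in isolation. A minimal sketch (assuming only numpy, not tied to the scripts above) comparing the bytes and object dtypes:

# bytes ("S") dtype strips trailing \x00 on access; object dtype keeps the raw payload
import numpy as np

payload = b"PK\x05\x06" + b"\x00" * 18          # ends in null bytes, like a zip/npz archive
as_bytes = np.array([[payload]], dtype=bytes)
as_object = np.array([[payload]], dtype=object)

print(len(payload), len(as_bytes[0][0]), len(as_object[0][0]))
# prints: 22 4 22

Applied to the reproduction above, this means returning np.array([[output]], dtype=object) from the encoder instead of dtype=bytes should preserve the full npz/h5 payload.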

github-actions[bot] commented 1 year ago

This issue is stale because it has been open 21 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot] commented 11 months ago

This issue was closed because it has been stalled for 7 days with no activity.