ELS-RD / transformer-deploy

Efficient, scalable and enterprise-grade CPU/GPU inference server for 🤗 Hugging Face transformer models 🚀
https://els-rd.github.io/transformer-deploy/
Apache License 2.0

Attempting to run T5 ORT model in Triton inference server #157

Open · samiur opened this issue 1 year ago

samiur commented 1 year ago

Hi there,

Thanks again for this library!

We're trying to convert a fine-tuned T5 model to ONNX and run it in Triton. We've managed to convert the model to ONNX and, following the T5 notebook guide, run it just fine in Python.

But getting it to run in Triton has been a challenge. In particular, we're not sure how the past_key_values inputs should be passed through in Triton. Our decoder config is as follows:

name: "t5-dec-if-node_onnx_model"
max_batch_size: 0
platform: "onnxruntime_onnx"
default_model_filename: "model.bin"
input [
    {
        name: "input_ids"
        data_type: TYPE_INT32
        dims: [ -1, -1 ]
    },
    {
        name: "encoder_hidden_states"
        data_type: TYPE_FP32
        dims: [ -1, -1, 2048 ]
    },
    {
        name: "enable_cache"
        data_type: TYPE_BOOL
        dims: [ 1 ]
    },
    {
        name: "past_key_values.0.decoder.key"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
    },
    {
        name: "past_key_values.0.decoder.value"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
    },
    {
        name: "past_key_values.0.encoder.key"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
    },
    {
        name: "past_key_values.0.encoder.value"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
    }
    ...
]
output [
    {
        name: "logits"
        data_type: TYPE_FP32
        dims: [ -1, -1, 32128 ]
    }
]
instance_group [
    {
      count: 1
      kind: KIND_GPU
    }
]
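
(Only the first layer's cache tensors are shown; the "..." elides the remaining layers.) One way to check exactly which inputs the loaded model declares is to ask Triton for the model metadata. A quick sketch using the same HTTP client (the URL is an assumption for a local deployment):

import tritonclient.http

client = tritonclient.http.InferenceServerClient(url="localhost:8000")

# The metadata lists every input tensor the ONNX graph exposes, including
# each layer's past_key_values.* tensors, so you can see what Triton expects.
meta = client.get_model_metadata("t5-dec-if-node_onnx_model")
for inp in meta["inputs"]:
    print(inp["name"], inp["datatype"], inp["shape"])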

And when we do the following:

import numpy as np
import tritonclient.http

# Connect to the Triton server (URL assumed; adjust for your deployment).
triton_client = tritonclient.http.InferenceServerClient(url="localhost:8000")

# input_ids is an int32 array of shape (1, 24); encoder_hidden_states is a
# float32 array of shape (1, 24, 2048), both produced earlier.
input_1 = tritonclient.http.InferInput(name="input_ids", shape=(1, 24), datatype="INT32")
input_2 = tritonclient.http.InferInput(name="encoder_hidden_states", shape=(1, 24, 2048), datatype="FP32")
input_3 = tritonclient.http.InferInput(name="enable_cache", shape=(1,), datatype="BOOL")

input_1.set_data_from_numpy(input_ids)
input_2.set_data_from_numpy(encoder_hidden_states)
input_3.set_data_from_numpy(np.asarray([True]))

result = triton_client.infer(
    model_name="t5-dec-if-node_onnx_model",
    inputs=[input_1, input_2, input_3],
    outputs=[tritonclient.http.InferRequestedOutput(name="logits", binary_data=False)],
)

We get this error:

InferenceServerException: [request id: <id_unknown>] expected 99 inputs but got 3 inputs for model 't5-dec-if-node_onnx_model'

Any idea how we can fix this?
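
The arithmetic behind the error is worth spelling out: 99 expected inputs = the 3 regular inputs declared above plus 24 layers × 4 cache tensors (decoder/encoder × key/value). Because the ONNX graph declares every past_key_values tensor as a required input, a request has to supply all of them, even on the first decoding step. Below is a minimal sketch of how that could look, reusing input_1/input_2/input_3 from above and feeding a zero-length sequence axis to mean "no cache yet"; whether the ONNX Runtime backend accepts zero-sized dimensions here is an assumption to verify:

import numpy as np
import tritonclient.http

NUM_LAYERS, NUM_HEADS, HEAD_DIM = 24, 32, 64  # inferred from the config and the "99 inputs" count

inputs = [input_1, input_2, input_3]  # the three inputs built earlier
for layer in range(NUM_LAYERS):
    for branch in ("decoder", "encoder"):
        for kind in ("key", "value"):
            name = f"past_key_values.{layer}.{branch}.{kind}"
            # Shape follows the config: (batch, heads, past_seq_len, head_dim);
            # past_seq_len = 0 stands in for an empty cache on the first step.
            shape = (1, NUM_HEADS, 0, HEAD_DIM)
            past = tritonclient.http.InferInput(name=name, shape=shape, datatype="FP32")
            past.set_data_from_numpy(np.zeros(shape, dtype=np.float32))
            inputs.append(past)

result = triton_client.infer(
    model_name="t5-dec-if-node_onnx_model",
    inputs=inputs,
    outputs=[tritonclient.http.InferRequestedOutput(name="logits", binary_data=False)],
)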

ayoub-louati commented 1 year ago

Hello, thanks for trying our library! We are actually working on adding T5 officially to the convert script so that the conversion can be done with a one-line command. It will be added very soon (the ONNX conversion maybe in less than a week). In the meantime, if you want, I can help you with the Triton configuration (it is a little bit complicated).