triton-inference-server / fil_backend

FIL backend for the Triton Inference Server
Apache License 2.0

Triton 24.04: Incorrect input type passed to GTIL predict() #391

Closed: casassg closed this issue 2 days ago

casassg commented 1 month ago

We are seeing an issue with the FIL backend after the 24.04 release. The same model artifact works on 24.01 without any issue.

    "error": "failed to infer: rpc error: code = Internal desc = inference failed: UNKNOWN_MODEL_ERROR: http.NewRequest: rpc error: code = Internal desc = [20:25:43] /rapids_triton/build/_deps/treelite-src/src/gtil/predict.cc:336: Incorrect input type passed to GTIL predict(). Expected: float64, Got: float32"

Model: exported LightGBM model

config.pbtxt:

backend: "fil"
max_batch_size: 512
dynamic_batching {}
instance_group [{ kind: KIND_AUTO }]
input [
    {
        name: "input__0"
        data_type: TYPE_FP32
        dims: [235]
    }
]
output [
    {
        name: "output__0"
        data_type: TYPE_FP32
        dims: [1]
    },
    {
        name: "treeshap_output"
        data_type: TYPE_FP32
        dims: [236]
    }
]
parameters [
    {
        key: "model_type"
        value: { string_value: "lightgbm" }
    },
    {
        key: "output_class"
        value: { string_value: "false" }
    }
]

We suspect this is related to the Treelite 4.0 release, since the code that raises this error was introduced in https://github.com/dmlc/treelite/pull/528.

Our theory is that Treelite 4.0 implements fp64 support but removes compatibility with fp32.
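To make the theory concrete, here is a minimal NumPy sketch of the kind of strict dtype check the error message implies. `strict_predict_check` is a hypothetical stand-in, not Treelite's actual code, and treating the model as expecting float64 is an assumption taken from the "Expected: float64" part of the message. The sketch also shows that widening float32 input to float64 is exact, so converting before the call would satisfy such a check.

```python
import numpy as np


def strict_predict_check(x: np.ndarray, model_dtype: np.dtype) -> None:
    """Hypothetical stand-in for a strict input-type check; not Treelite's code."""
    if x.dtype != model_dtype:
        raise TypeError(
            "Incorrect input type passed to GTIL predict(). "
            f"Expected: {np.dtype(model_dtype).name}, Got: {x.dtype.name}"
        )


# Triton hands the backend FP32 data because config.pbtxt declares TYPE_FP32.
rng = np.random.default_rng(0)
x32 = rng.random((1, 235), dtype=np.float32)
# Assumption: the loaded LightGBM model expects float64 input (per the error text).
model_dtype = np.dtype(np.float64)

try:
    strict_predict_check(x32, model_dtype)
except TypeError as err:
    print(err)  # Incorrect input type passed to GTIL predict(). Expected: float64, Got: float32

# Every float32 value is exactly representable as float64, so widening loses nothing.
x64 = x32.astype(model_dtype)
strict_predict_check(x64, model_dtype)  # no exception
assert np.array_equal(x64, x32)
```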

hcho3 commented 1 month ago

Thanks for raising the issue. I'd like to start working on a fix soon.

Question: In the model repository, did you put the LightGBM model file (model.txt) or the converted Treelite checkpoint file (checkpoint.tl)?

casassg commented 1 month ago

@hcho3 It's an exported model.txt. In case it's useful, here's the head:

tree
version=v3
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=234
objective=binary sigmoid:1
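The header fields are enough to cross-check config.pbtxt: `max_feature_idx=234` means 235 input features (indices are 0-based), matching `dims: [235]`, and the `treeshap_output` width of 236 is one more than that, consistent with one SHAP value per feature plus a bias column. A minimal sketch; `parse_lightgbm_header` and `triton_dims` are hypothetical helpers, and the n_features + 1 rule is assumed only for single-output (num_class=1) models like this one.

```python
def parse_lightgbm_header(text: str) -> dict:
    """Collect the key=value lines that precede the first Tree= section of model.txt."""
    header = {}
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("Tree="):  # tree definitions start; header is over
            break
        key, sep, value = line.partition("=")
        if sep:  # skips the leading "tree" marker and blank lines
            header[key] = value
    return header


def triton_dims(header: dict) -> dict:
    """Expected dims for input__0 and treeshap_output of a single-output model (assumption)."""
    if int(header.get("num_class", "1")) != 1:
        raise ValueError("this sketch only handles num_class=1")
    n_features = int(header["max_feature_idx"]) + 1  # feature indices are 0-based
    return {"input__0": n_features, "treeshap_output": n_features + 1}


HEAD = """tree
version=v3
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=234
objective=binary sigmoid:1
"""

header = parse_lightgbm_header(HEAD)
print(header["version"], triton_dims(header))
# v3 {'input__0': 235, 'treeshap_output': 236}
```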
hcho3 commented 1 month ago

@casassg When you say "export", do you mean exporting the model from LightGBM, or exporting it from Treelite? Can you post the code snippet for exporting the model?

casassg commented 4 weeks ago

I mean saving the LightGBM model:

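import lightgbm as lgb

model: lgb.LGBMModel  # a fitted LightGBM scikit-learn estimator
model.booster_.save_model(tmp_model_file)  # writes the text model.txt format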
hcho3 commented 6 days ago

@casassg Apologies for the delay. I have recently started troubleshooting the issue, but I am unable to reproduce the error on my end. Can you look at my setup and see how it differs from yours?

Training script, using LightGBM 4.4.0 and scikit-learn 1.4.1:

import os

import lightgbm
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_features=235, n_informative=200)
print(X.dtype, y.dtype)  # Prints: float64 int64

dtrain = lightgbm.Dataset(X, label=y)
params = {
    "num_leaves": 31,
    "metric": "binary_logloss",
    "objective": "binary",
}
bst = lightgbm.train(
    params,
    dtrain,
    num_boost_round=10,
    valid_sets=[dtrain],
    valid_names=["train"],
    callbacks=[lightgbm.log_evaluation()],
)
# Save into the model repository mounted into the container (-v $PWD/models:/models).
os.makedirs("models/example/1", exist_ok=True)
bst.save_model("models/example/1/model.txt")

First few lines of model.txt:

tree
version=v4
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=234
objective=binary sigmoid:1

Triton-FIL model configuration (config.pbtxt):

backend: "fil"
max_batch_size: 512
dynamic_batching {}
instance_group [{ kind: KIND_AUTO }]
input [
    {
        name: "input__0"
        data_type: TYPE_FP32
        dims: [235]
    }
]
output [
    {
        name: "output__0"
        data_type: TYPE_FP32
        dims: [1]
    },
    {
        name: "treeshap_output"
        data_type: TYPE_FP32
        dims: [236]
    }
]
parameters [
    {
        key: "model_type"
        value: { string_value: "lightgbm" }
    },
    {
        key: "output_class"
        value: { string_value: "false" }
    }
]

I launched the Triton server locally using the Docker container:

docker run --rm -it --gpus '"device=0"' --network host \
    -v $PWD/models:/models nvcr.io/nvidia/tritonserver:24.05-py3 \
    tritonserver --model-repository=/models
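Before running a client, it can help to wait until the server and model report ready. A minimal standard-library sketch, assuming Triton's HTTP endpoint is on localhost:8000 as in the command above; it polls the KServe v2 readiness routes (`/v2/health/ready` and `/v2/models/<name>/ready`), and `wait_until_ready` is a hypothetical helper name.

```python
import time
import urllib.request


def wait_until_ready(base_url: str = "http://localhost:8000",
                     model: str = "",
                     timeout_s: float = 60.0) -> bool:
    """Poll Triton's readiness endpoint until it answers 200 or timeout_s elapses."""
    path = f"/v2/models/{model}/ready" if model else "/v2/health/ready"
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(base_url + path, timeout=2) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            pass  # connection refused, HTTP error (URLError is an OSError), or timeout
        if time.monotonic() >= deadline:
            return False
        time.sleep(0.5)
```

For this thread's setup, `wait_until_ready(model="example")` returns True once the FIL model has loaded, and False if it has not become ready within the timeout.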

Using the following client inference script, I got the predictions without any error:

import numpy as np
import tritonclient.http as triton_http

x = np.zeros((1, 235), dtype=np.float32)

client = triton_http.InferenceServerClient(url="localhost:8000")
triton_input = triton_http.InferInput("input__0", x.shape, "FP32")
triton_input.set_data_from_numpy(x)
output0 = triton_http.InferRequestedOutput("output__0")
output_treeshap = triton_http.InferRequestedOutput("treeshap_output")

r = client.infer(
    "example",
    model_version="1",
    inputs=[triton_input],
    outputs=[output0, output_treeshap],
)
print(r.as_numpy("output__0"))
print(r.as_numpy("treeshap_output"))
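Note that `make_classification` returns float64 features while config.pbtxt declares `TYPE_FP32`, so the client has to send FP32 data; the float64 expectation in the error message arises on the server side, not in the request. A minimal sketch of that client-side check, assuming only NumPy; the mapping is a hand-written subset (the full one is `tritonclient.utils.np_to_triton_dtype`), and `check_against_config` is hypothetical.

```python
import numpy as np

# Hand-written subset of NumPy -> Triton datatype strings
# (tritonclient.utils.np_to_triton_dtype covers the full set).
NP_TO_TRITON = {
    np.dtype(np.float16): "FP16",
    np.dtype(np.float32): "FP32",
    np.dtype(np.float64): "FP64",
    np.dtype(np.int32): "INT32",
    np.dtype(np.int64): "INT64",
}


def check_against_config(x: np.ndarray, config_data_type: str) -> str:
    """Return x's Triton datatype, raising if it differs from config.pbtxt's TYPE_* value."""
    triton_dtype = NP_TO_TRITON[x.dtype]
    expected = config_data_type.removeprefix("TYPE_")
    if triton_dtype != expected:
        raise ValueError(f"input is {triton_dtype}, but config.pbtxt declares {config_data_type}")
    return triton_dtype


X = np.random.default_rng(0).random((4, 235))  # float64, like make_classification output
x = X.astype(np.float32)                         # cast to match TYPE_FP32
print(check_against_config(x, "TYPE_FP32"))      # FP32
```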
casassg commented 3 days ago

While trying to run your example, I noticed one difference: I'm running LightGBM 3.3.5 rather than 4.4.0:

import os

import lightgbm as lgb
from sklearn.datasets import make_classification
print(lgb.__version__)  # Prints: 3.3.5

X, y = make_classification(n_samples=1000, n_features=235, n_informative=200)
print(X.dtype, y.dtype)  # Prints: float64 int64

m = lgb.LGBMClassifier()
m.fit(X, y)

os.makedirs("models/example/1", exist_ok=True)
m.booster_.save_model("models/example/1/model.txt")

First few lines of model.txt:

tree
version=v3
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=234
objective=binary sigmoid:1

Server launch:

docker run --rm -d -p 8000:8000 -v $PWD/models:/models nvcr.io/nvidia/tritonserver:24.05-py3 tritonserver --model-repository=/models

Client script:

import numpy as np
import tritonclient.http as triton_http

x = np.zeros((1, 235), dtype=np.float32)

client = triton_http.InferenceServerClient("localhost:8000")
triton_input = triton_http.InferInput("input__0", x.shape, "FP32")
triton_input.set_data_from_numpy(x)
output0 = triton_http.InferRequestedOutput("output__0")
output_treeshap = triton_http.InferRequestedOutput("treeshap_output")

r = client.infer(
    "example",
    model_version="1",
    inputs=[triton_input],
    outputs=[output0, output_treeshap],
)
print(r.as_numpy("output__0"))
print(r.as_numpy("treeshap_output"))

This errors out:

InferenceServerException: [500] [21:15:25] /rapids_triton/build/_deps/treelite-src/src/gtil/predict.cc:391: Incorrect input type passed to GTIL predict(). Expected: float64, Got: float32
hcho3 commented 3 days ago

@casassg Interesting. I was able to reproduce the error when I turned off the GPU and ran the inference on the CPU:

docker run --rm -it --network host \
    -v $PWD/models:/models nvcr.io/nvidia/tritonserver:24.05-py3 \
    tritonserver --model-repository=/models

Error:

tritonclient.utils.InferenceServerException: [500] [23:04:45] /rapids_triton/build/_deps/treelite-src/src/gtil/predict.cc:391: Incorrect input type passed to GTIL predict(). Expected: float64, Got: float32
casassg commented 3 days ago

Maybe something differs between the CPU GTIL functions and the GPU ones?

hcho3 commented 3 days ago

@casassg Yes, the bug only affects GTIL. I will make a fix soon. In the meantime, you can work around the bug by modifying config.pbtxt as follows:

backend: "fil"
max_batch_size: 512
dynamic_batching {}
instance_group [{ kind: KIND_AUTO }]
input [
    {
        name: "input__0"
        data_type: TYPE_FP32
        dims: [235]
    }
]
output [
    {
        name: "output__0"
        data_type: TYPE_FP32
        dims: [1]
    },
    {
        name: "treeshap_output"
        data_type: TYPE_FP32
        dims: [236]
    }
]
parameters [
    {
        key: "model_type"
        value: { string_value: "lightgbm" }
    },
    {
        key: "output_class"
        value: { string_value: "false" }
    },
    {
        key: "use_experimental_optimizations"
        value: { string_value: "true" }
    }
]

Note the addition of use_experimental_optimizations, which instructs the backend to use an alternative CPU implementation in place of GTIL.
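If several model repositories need the workaround, it can be applied programmatically. A minimal sketch, assuming (as in every config in this thread) that the `parameters [ ... ]` block is the last block in config.pbtxt and already has at least one entry; `add_workaround` is a hypothetical helper that does plain text editing rather than parsing the protobuf text format.

```python
WORKAROUND_PARAM = """    {
        key: "use_experimental_optimizations"
        value: { string_value: "true" }
    }"""


def add_workaround(config_text: str) -> str:
    """Append use_experimental_optimizations to the parameters block of a config.pbtxt.

    Assumes `parameters [ ... ]` is the last block in the file and is non-empty.
    """
    if "use_experimental_optimizations" in config_text:
        return config_text  # already patched; keep the edit idempotent
    start = config_text.index("parameters [")
    end = config_text.rindex("]")  # assumed to close the parameters block
    if end < start:
        raise ValueError("parameters block is not the last block in config.pbtxt")
    body = config_text[start:end].rstrip()  # ends with the last entry's closing brace
    return config_text[:start] + body + ",\n" + WORKAROUND_PARAM + "\n" + config_text[end:]
```

With the paths used in this thread, the file could be patched via `pathlib.Path("models/example/config.pbtxt")`, reading it with `read_text()` and writing back `add_workaround(...)` with `write_text()`.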

hcho3 commented 3 days ago

The fix is available at https://github.com/triton-inference-server/fil_backend/pull/394 and will be part of the upcoming 24.06 release.