triton-inference-server / fil_backend

FIL backend for the Triton Inference Server
Apache License 2.0

[BUG] getting error from `simple_xgboost_example` notebook #166

Closed. rnyak closed this issue 2 years ago.

rnyak commented 2 years ago

I get the following error when executing `predictions = [round(value) for value in result_http]` in the example notebook:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_394/256293315.py in <module>
      1 # Check that we got the same accuracy as previously
----> 2 predictions = [round(value) for value in result_http]

/tmp/ipykernel_394/256293315.py in <listcomp>(.0)
      1 # Check that we got the same accuracy as previously
----> 2 predictions = [round(value) for value in result_http]

TypeError: type numpy.ndarray doesn't define __round__ method
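A minimal reproduction sketch, using a small hypothetical array in place of the real `result_http` (its values copied from the output shown at the end of this issue), suggests why the built-in `round()` receives an `ndarray` rather than a number. Iterating over a 2-D array yields whole rows:

```python
import numpy as np

# Hypothetical stand-in for result_http, shaped like the output shown in this
# issue: one row per sample, two columns per row.
result_http = np.array([[1.0, 1.0],
                        [1.0, 0.0],
                        [0.0, 0.0]], dtype=np.float32)

# Iterating over a 2-D array yields its rows, so each `value` in the list
# comprehension is itself an ndarray of shape (2,), not a scalar.
first_value = next(iter(result_http))
print(type(first_value).__name__, first_value.shape)  # ndarray (2,)

# np.round works element-wise on the whole array, which is why switching to
# numpy.round() avoids the TypeError; the result is still 2-D, though.
rounded = np.round(result_http)
print(rounded.shape)  # (3, 2)
```

Because the rounded result keeps its `(n, 2)` shape, the shape problem resurfaces at the scoring step.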

I can avoid that error by using `numpy.round()`, but then I get the following error when computing `accuracy_score`:

accuracy = accuracy_score(y_test, predictions)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_394/3225859865.py in <module>
      1 # # Check that we got the same accuracy as previously
      2 # predictions = [round(value) for value in result_http]
----> 3 accuracy = accuracy_score(y_test, predictions)
      4 print("Accuracy: {:.2f}".format(accuracy * 100.0))

/usr/local/lib/python3.8/dist-packages/sklearn/metrics/_classification.py in accuracy_score(y_true, y_pred, normalize, sample_weight)
    209 
    210     # Compute accuracy for each possible representation
--> 211     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
    212     check_consistent_length(y_true, y_pred, sample_weight)
    213     if y_type.startswith("multilabel"):

/usr/local/lib/python3.8/dist-packages/sklearn/metrics/_classification.py in _check_targets(y_true, y_pred)
     91 
     92     if len(y_type) > 1:
---> 93         raise ValueError(
     94             "Classification metrics can't handle a mix of {0} and {1} targets".format(
     95                 type_true, type_pred

ValueError: Classification metrics can't handle a mix of binary and multilabel-indicator targets
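As the traceback shows, `_check_targets()` raises when `y_true` and `y_pred` are classified as different target types. The sketch below uses scikit-learn's `type_of_target()` on hypothetical stand-ins for `y_test` and the rounded predictions (only their shapes mirror this issue) to show which types are inferred. It assumes a scikit-learn version in which `_check_targets()` relies on `type_of_target()`, as current releases do:

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Hypothetical labels and predictions standing in for y_test and the rounded
# Triton output; only their shapes and 0/1 values matter here.
y_test = np.array([1, 0, 0])
predictions = np.round(np.array([[1.0, 1.0],
                                 [1.0, 0.0],
                                 [0.0, 0.0]], dtype=np.float32))

# A 1-D vector of two distinct labels is "binary", while a 2-D 0/1 array
# with more than one column is treated as a multilabel indicator matrix.
print(type_of_target(y_test))       # binary
print(type_of_target(predictions))  # multilabel-indicator
```

This type pair matches the error message: `accuracy_score` sees a binary ground truth and a multilabel-indicator prediction.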

It looks like the error occurs because `result_http` is a 2-D array (two values per sample) rather than a 1-D array, as shown below:

array([[1., 1.],
       [1., 0.],
       [1., 1.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)
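Until the output shape is resolved, a defensive check before scoring makes this kind of mismatch fail with a clear message instead of deep inside scikit-learn. The helper below is hypothetical and not part of the notebook or the FIL backend. It accepts shape `(n,)` or `(n, 1)` and deliberately refuses to pick a column from wider outputs, because this issue does not establish what the second column represents:

```python
import numpy as np


def as_1d_predictions(result) -> np.ndarray:
    """Hypothetical helper: turn a Triton result into a 1-D integer label vector.

    Accepts shape (n,) or (n, 1). Any other shape is rejected rather than
    guessing which column holds the prediction.
    """
    result = np.asarray(result)
    if result.ndim == 2 and result.shape[1] == 1:
        result = result.ravel()  # (n, 1) -> (n,)
    if result.ndim != 1:
        raise ValueError(
            f"expected one value per sample, got array of shape {result.shape}"
        )
    # Round scores to the nearest class label and cast for sklearn metrics.
    return np.round(result).astype(int)


print(as_1d_predictions(np.array([[0.2], [0.8]])))  # [0 1]
```

With the `(n, 2)` array above, this helper raises a `ValueError` that reports the offending shape.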