onnx / sklearn-onnx

Convert scikit-learn models and pipelines to ONNX
Apache License 2.0

Wrong predictions with `LinearDiscriminantAnalysis` with 2 classes #1116

Closed FrancescMartiEscofetQC closed 2 months ago

FrancescMartiEscofetQC commented 4 months ago

When converting a trained LinearDiscriminantAnalysis with only two classes, the output probabilities are wrong.

Code

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import numpy as np
from skl2onnx.convert import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnxruntime as rt

m = LinearDiscriminantAnalysis()
rng = np.random.default_rng(7)
X = rng.standard_normal((1000,10))
y = rng.integers(0,2,1000)
m.fit(X,y)
m.predict_proba(X)

Outputs:

array([[0.47945511, 0.52054489],
       [0.50213593, 0.49786407],
       [0.44575789, 0.55424211],
       ...,
       [0.46666746, 0.53333254],
       [0.54823724, 0.45176276],
       [0.49074771, 0.50925229]])

Converting it to ONNX:

onnx_model = convert_sklearn(m, initial_types=[("X", FloatTensorType([None, 10]))], options={"zipmap": False})
sess = rt.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
sess.run(["probabilities"], {"X": X.astype(np.float32)})

Outputs:

[array([[0.4589795 , 0.5410205 ],
        [0.5042718 , 0.49572822],
        [0.39277768, 0.6072224 ],
        ...,
        [0.4336299 , 0.5663701 ],
        [0.5955848 , 0.4044152 ],
        [0.48150176, 0.51849824]], dtype=float32)]

With more than two classes, the converter produces correct probabilities.

Versions: skl2onnx 1.17.0, scikit-learn 1.5.1, Python 3.12.4

xadupre commented 3 months ago

Thanks for reporting this. The discrepancy seems huge. This model should convert with double precision: can you change FloatTensorType to DoubleTensorType? A weight matrix with coefficients of different orders of magnitude introduces large discrepancies when switching from double to float. See https://onnx.ai/sklearn-onnx/auto_examples/plot_gpr.html.

xadupre commented 2 months ago

I'll close the issue. Feel free to reopen it.