guitargeek / XGBoost-FastForest

Minimal library code to deploy XGBoost models in C++.
MIT License

FastForest inference discrepancies with XGBoost #21

Closed andriiknu closed 11 months ago

andriiknu commented 11 months ago

Hi!

We've integrated your library into our project, aiming to achieve machine learning results that closely match those of a reference XGBoost implementation.

The basic pipeline is that we dump the XGBoost JSON model to a txt file and use that file to perform inference with FastForest. We've encountered discrepancies when comparing the results of our FastForest model with those of XGBoost.
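For context, this is roughly what that pipeline looks like on the XGBoost side (a minimal sketch with placeholder file names, using the public get_booster() accessor; the txt dump is what FastForest's load_txt consumes):

from xgboost import XGBClassifier

# Load the trained model from JSON and dump it to the txt format that
# FastForest reads with fastforest::load_txt (see the reproducers below).
model = XGBClassifier()
model.load_model("model.json")
model.get_booster().dump_model("model.txt")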

Throughout our experiments, we've observed instances (with certain input vectors) where the FastForest model produces unexpected results that differ from XGBoost's.

We have reduced the problem to a single input vector. This is a self-contained minimal repo to reproduce the problem: https://github.com/andriiknu/fastforest_issue/tree/master

Thank you in advance! Looking forward to any assistance.

guitargeek commented 11 months ago

Hi, thanks for your report! It looks to me like these are floating-point precision problems.

I simplified your reproducer a bit to be standalone in Python using cppyy (or ROOT alternatively), and also exporting the txt on the fly. For simplicity, I also removed all but the first tree in the forest from the JSON.

Here is the new input file: model.json.

My finding is that this is not a FastForest problem, but an XGBoost floating-point precision problem. First of all, XGBoost creates floating-point inconsistencies between the JSON and the txt, so I think it's not correct to say that the txt conversion by XGBoost is not the problem.

But it's also not the only problem: I think XGBoost also introduces further internal floating-point errors when reading the JSON. In my modified example, I chose the value of f14 for both input_1 and input_2 such that it falls just below the relevant decision boundary in both the JSON and the txt case (you can search the files for the first digits of the f14 value to find these cuts if you want to check). It seems that XGBoost gets the answer wrong when the feature value is very close to the boundary.
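In fact, at single precision the two spellings of the cut and the culprit feature value all collapse to the same number, so the split comparison is decided by sub-ulp details. A quick check (added here for illustration, not part of the original reproducer):

import numpy as np

# The cut as printed in the JSON and the cut as printed in the txt dump
# round to the same float32...
assert np.float32(0.99902344) == np.float32(0.999023438)

# ...and that float32 is exactly the culprit value of f14 from the input
# vector below.
assert np.float32(0.99902344) == np.float32(0.9990234375)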

But is this really a problem? I mean, machine learning is not that precise anyway, and as you can see, even XGBoost itself doesn't care about these precision problems, which introduces inconsistencies/errors that FastForest apparently doesn't have. FastForest also uses single precision for performance reasons.

from xgboost import XGBClassifier
import cppyy # One could also use ROOT for the C++ interpreter
import numpy as np

model = XGBClassifier()
model.load_model("model.json")

# Inverse of the sigmoid (logit), mapping probabilities back to raw scores
def inv(x):
    return -np.log(1.0 / x - 1.0)

input_1 = np.array(
    [
        [
            1.7031662464141846,
            1.6687114238739014,
            1.5754574537277222,
            2.900691509246826,
            100.23957824707031,
            155.67808532714844,
            287.4523620605469,
            155.5530548095703,
            50.84375,
            145.75,
            54.15625,
            46.09375,
            0.1103515625,
            0.2412109375,
            0.9990234375, # This is the culprit!
            0.7802734375,
            0.2880859375,
            0.01377105712890625,
            0.0240936279296875,
            0.12042236328125,
        ]
    ]
)

# The relevant decision boundary according to the JSON is 0.99902344, and
# according to the txt it's 0.999023438. This is already inconsistent, so
# indeed there is a problem with the conversion to text file by XGBoost.

# In fact, f14 for both input_1 and input_2 is below the decision boundary in
# any case, so it's actually XGBoost that has a problem! It's probably doing
# some transformations with large floating point errors when importing the
# JSON.

score = inv(model.predict_proba(input_1)[:, 1][0])
print(f"score = {score}")

input_2 = np.array(input_1)
input_2[0,14] = 0.9990234

score = inv(model.predict_proba(input_2)[:, 1][0])
print(f"score = {score}")

booster = model._Booster

booster.dump_model("model_1.txt")

prefix = "/home/jonas/code/XGBoost-FastForest/install/" # change this

cppyy.include(prefix + "include/fastforest.h")
cppyy.load_library(prefix + "lib/libfastforest.so")

cppyy.cppdef(
    """
fastforest::FastForest get_model(const std::string& path_to_model, std::size_t nfeatures) {
    std::vector<std::string> feature_names(nfeatures);
    for (std::size_t i = 0; i < nfeatures; ++i) {
        feature_names[i] = "f" + std::to_string(i);
    }
    return fastforest::load_txt(path_to_model, feature_names);
}

"""
)

ff = cppyy.gbl.get_model("model_1.txt", len(input_1[0]))

# score = 1.0 / (1.0 + np.exp(-ff(np.array(input_1[0], dtype=np.float32), 0.0)))
score = ff(np.array(input_1[0], dtype=np.float64), 0.0)
print("score =", score)

score = ff(np.array(input_2[0], dtype=np.float64), 0.0)
print("score =", score)

Here is the output I get; the first two scores are from XGBoost and the last two from FastForest:

score = 0.3120000909889086
score = -1.4485712587158626
score = -1.44857132
score = -1.44857132
guitargeek commented 11 months ago

Here is a standalone Python code to showcase the XGBoost problems:

from xgboost import XGBClassifier
import numpy as np
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=100, n_features=1, n_informative=1, n_redundant=0, n_clusters_per_class=1, random_state=1337
)

model = XGBClassifier(max_depth=1, n_estimators=1)
model.fit(X, y)

model._Booster.save_model("model.json")
model._Booster.dump_model("model_dump.json", dump_format="json")
model._Booster.dump_model("model_dump.txt", dump_format="text")

model_1 = XGBClassifier()
model_1.load_model("model.json")

def predict(model, x):
    return model.predict(np.array([[x]]))

# This is the split value in the model.json (maybe you need to change this
# number if your random number generation gave you a different model)
split = -0.47289193

# So the way how the split value is stored suggests that the comparison will be
# accurate to the 7th significant digit. Let's try to change it to get two
# values up and down of the split:

# Right side of the split. Should predict class zero but it doesn't!
print(predict(model_1, -0.47289194))

# Correctly identified as on the left side of the split.
print(predict(model_1, -0.47289192))

You'll see that:

1) The values in the files output by save_model and dump_model are not identical, because they are written with different floating-point precision.
2) XGBoost does something that makes the comparisons at splits less precise than the number of significant digits in the JSON file from save_model suggests. FastForest doesn't have this problem and makes more accurate cut comparisons, hence you get different results if one of the feature values is very close to a split.
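If you want to inspect the stored cut values yourself, something along these lines should work (a sketch; these JSON field names are what recent XGBoost versions write and may differ in yours):

import json

# Cuts as stored by save_model (full model format):
with open("model.json") as f:
    saved = json.load(f)
print(saved["learner"]["gradient_booster"]["model"]["trees"][0]["split_conditions"])

# Cut of the root node as stored by dump_model (dump format):
with open("model_dump.json") as f:
    dumped = json.load(f)
print(dumped[0]["split_condition"])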

I don't think there is anything we can do on the FastForest side, because trying to be floating-point-error-compatible with XGBoost would be a pain and is not required for practical purposes.

I understand it's annoying for you, though, if you want to reproduce the results exactly. But maybe now that we understand where these differences come from, they can be tolerated?

eguiraud commented 11 months ago

Thank you for looking into this, and the quick reaction time!

About the latter reproducer: single precision simply cannot represent the difference between -0.47289194 and -0.47289192 (nor would I expect single precision to correctly represent differences in the 9th significant digit; they are effectively different spellings of the same number):

// -0.4728919{2,3,4}f all have the same single-precision floating point representation
root [4] std::cout << std::setprecision(30) << -0.47289194f << std::endl;
-0.47289192676544189453125
root [5] std::cout << std::setprecision(30) << -0.47289192f << std::endl;
-0.47289192676544189453125
root [6] std::cout << std::setprecision(30) << -0.47289193f << std::endl;
-0.47289192676544189453125
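As a cross-check in numpy (not from the original comment):

import numpy as np

# All three decimal literals round to one and the same float32 value,
# so this set has a single element.
print({np.float32(x) for x in (-0.47289192, -0.47289193, -0.47289194)})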

So I think this is a less interesting case than the one originally presented.

eguiraud commented 11 months ago

"But is this really a problem?"

Only insofar as we want good agreement between the results of ROOT's AGC implementation and the reference implementation.

The annoying aspect of the issue is that, due to the non-linear nature of the models, small discrepancies compatible with numerical errors can result in completely different probability scores, which in turn result in significantly different analysis results -- in a way that is not straightforward to account for.

But it's not FastForest's problem! :) It looks like we'll have to revise our definition of "good agreement" for the ML component of the AGC.
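To illustrate the mechanism: a tree's score is piecewise constant, so crossing a single cut by one ulp swaps the leaf and, with it, the score (a toy sketch; the cut and leaf values are made up, echoing the scores from the first reproducer):

# A single split acts as a step function in the feature.
def toy_tree_score(f14: float) -> float:
    cut = 0.9990234375  # hypothetical cut at the f14 boundary
    return 0.312 if f14 < cut else -1.449  # hypothetical leaf scores

print(toy_tree_score(0.9990233))  # 0.312
print(toy_tree_score(0.9990235))  # -1.449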

guitargeek commented 11 months ago

Thanks @eguiraud for pointing that out! In fact, I assumed that XGBoost uses double precision internally; that's why I was surprised by the results. But clearly it does not, so it's actually comparable with FastForest, which also uses single precision.

There was also a difference in the logic I spotted and resolved: https://github.com/guitargeek/XGBoost-FastForest/commit/2a986401ecbe6d576b9472c025b77a89b3b1cfb6

@andriiknu, maybe the differences go away with fastforest master, which includes this commit?

If yes, please let me know so I can close this issue.

andriiknu commented 11 months ago

Hi! Thanks a lot for digging deep into this problem and for the detailed explanation! Your latest commit has fixed the discrepancies. The probability scores are now very close (within a tolerance of 1e-8)!

guitargeek commented 11 months ago

Great, that's good to hear, and fortunately the problem was not difficult to fix in the end :D