BayesWitnesses / m2cgen

Transform ML models into a native code (Java, C, Python, Go, JavaScript, Visual Basic, C#, R, PowerShell, PHP, Dart, Haskell, Ruby, F#, Rust) with zero dependencies
MIT License

add code for multioutput regression #576

Open AaronDavidSchneider opened 1 year ago

AaronDavidSchneider commented 1 year ago

This PR adds support for multioutput regression in XGBoost. The same approach could work for other booster classes as well, but there is no general attribute that reports the number of targets. That's why this solution is restricted to XGBoost.
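
To illustrate the idea, here is a minimal sketch in plain Python. It is not the code from this PR. It assumes the booster stores its trees interleaved by target (tree i belongs to target i % n_targets, similar to how multiclass boosters are usually laid out), and the names stump, predict_multioutput and base_score are hypothetical helpers for illustration only.

```python
from typing import Callable, List, Sequence

# A "tree" here is just a function from a feature vector to a leaf value.
Tree = Callable[[Sequence[float]], float]


def stump(feature: int, threshold: float, left: float, right: float) -> Tree:
    """Build a depth-1 toy tree (hypothetical helper, not an m2cgen API)."""
    def tree(x: Sequence[float]) -> float:
        return left if x[feature] < threshold else right
    return tree


def predict_multioutput(
    trees: List[Tree],
    n_targets: int,
    base_score: float,
    x: Sequence[float],
) -> List[float]:
    """Sum the trees of each target separately.

    Assumption: trees are interleaved by target, so target t owns
    trees[t], trees[t + n_targets], trees[t + 2 * n_targets], ...
    """
    outputs = []
    for t in range(n_targets):
        target_trees = trees[t::n_targets]
        outputs.append(base_score + sum(tree(x) for tree in target_trees))
    return outputs


# Two targets, two boosting rounds -> four trees, interleaved by target.
toy_trees = [
    stump(0, 0.5, -1.0, 1.0),  # round 1, target 0
    stump(1, 0.5, -2.0, 2.0),  # round 1, target 1
    stump(0, 0.2, -0.1, 0.1),  # round 2, target 0
    stump(1, 0.8, -0.2, 0.2),  # round 2, target 1
]

result = predict_multioutput(toy_trees, n_targets=2, base_score=0.5, x=[0.3, 0.9])
print(result)
```

The exported function would then return one value per target instead of a single scalar, which is why the number of targets has to be known at export time.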

Closes #559

AaronDavidSchneider commented 1 year ago

I have also added some tests. Here is another simple way to verify that it works:

import numpy as np
import m2cgen
from sklearn import datasets
from xgboost import XGBRegressor

n_targets = 3
n_features = 3
n_test = 20
n_train = 20

X, y = datasets.make_regression(n_targets=n_targets, n_features=n_features, n_samples=n_train, random_state=1)

multi_class_model_params = {
    'n_estimators': 3,
    'max_depth': 2
}

model = XGBRegressor(**multi_class_model_params).fit(X, y)

code = m2cgen.export_to_python(model, function_name='predict')
with open('test_file.py', 'w') as f:
    f.write(code)

from test_file import predict

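# Random test inputs: one row per sample, n_features columns.
inputs = np.random.random((n_test, n_features))

closeness = []
for input_i in inputs:
    # Prediction from the exported pure-Python code.
    hardcoded = predict(input_i)
    # Prediction from the XGBoost API for the same single sample.
    apipred = model.predict(input_i.reshape((1, n_features)))
    closeness.append(np.allclose(apipred, hardcoded))

print(f'All close: {np.all(closeness)}, fraction of close: {np.sum(closeness)/n_test}')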

This works well for me.

I also ran the make pre-pr command, which passed all tests but two on my MacBook. I don't know why those two failed, but I suspect a mismatch between the Python package versions installed on my machine and those on the test machine, since the same two failures also occur on the master branch.

Unfortunately, the docker command also fails on my machine. The failure appears to come from the docker build step rather than from my changes.

I hope that the CI will give more insight!

AaronDavidSchneider commented 1 year ago

Kindly pinging @izeigerman and @StrikerRUS