zhuang-hao-ming opened this issue 1 year ago
Can you please share more details on the dataset where you observe this difference?
Thank you very much for your response. I have shared a dataset that reproduces this difference on Google Drive:
https://drive.google.com/file/d/1FGDFcUM1-XnIUCt3jXQ9hGANQFF2vIeU/view?usp=sharing
The code that produces this difference is shown below. With the sklearnex patch the model gets an MSE of 67569.56223138234, while without the patch it gets an MSE of 20874.257079172243.
import numpy as np
from timeit import default_timer as timer

# Load all four arrays from the shared archive once
data = np.load('./data.npz')
x_train, y_train = data['x_train'], data['y_train']
x_test, y_test = data['x_test'], data['y_test']

# from sklearnex import patch_sklearn  # patch or not
# patch_sklearn()  # must run before the sklearn imports below
from sklearn import metrics
from sklearn.ensemble import RandomForestRegressor

params = {
    'n_estimators': 150,
    'n_jobs': -1,
    'random_state': 1,  # fixed seed so runs are repeatable
}

start = timer()
rf = RandomForestRegressor(**params).fit(x_train, y_train)
train_time = timer() - start
print(f"Training time: {train_time:.2f} s")

y_pred = rf.predict(x_test)
mse = metrics.mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")
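The reported figures put the patched MSE at roughly 3.2 times the unpatched one. Since both runs use a single random_state, one way to check whether the gap is systematic rather than a seed artifact is to repeat the fit over several seeds and compare the spread. The sketch below is a library-free harness under these assumptions: 1-D regression targets, an estimator exposing fit/predict, and a hypothetical make_model(seed) factory. mean_squared_error here is a plain reimplementation of the same formula as sklearn.metrics.mean_squared_error, so the harness can be checked without scikit-learn installed.

```python
import math
from typing import Callable, Iterable, List, Sequence, Tuple


def mean_squared_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean of squared residuals (same formula as sklearn's MSE for 1-D targets)."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    return sum((float(t) - float(p)) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mse_over_seeds(
    make_model: Callable[[int], object],
    x_train, y_train, x_test, y_test,
    seeds: Iterable[int],
) -> List[float]:
    """Fit one freshly built model per seed and collect its test-set MSE."""
    scores = []
    for seed in seeds:
        model = make_model(seed)  # hypothetical factory that takes a random_state
        model.fit(x_train, y_train)
        scores.append(mean_squared_error(y_test, model.predict(x_test)))
    return scores


def mean_and_std(scores: Sequence[float]) -> Tuple[float, float]:
    """Population mean and standard deviation of the per-seed scores."""
    mean = sum(scores) / len(scores)
    std = math.sqrt(sum((s - mean) ** 2 for s in scores) / len(scores))
    return mean, std
```

In practice make_model could be something like `lambda s: RandomForestRegressor(n_estimators=150, n_jobs=-1, random_state=s)`, run once in a patched interpreter and once in an unpatched one. A gap in mean MSE much larger than either standard deviation would suggest a systematic difference between the implementations rather than seed noise.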
Hi @zhuang-hao-ming, thank you for reporting this issue. I will start working on this today and will first try to reproduce your observations. I have just opened a PR in the oneDAL repo that changes how we sample features on node splits (https://github.com/oneapi-src/oneDAL/pull/2292), which could have an impact here. I will update you once I know more.
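The PR above concerns how features are sampled at node splits. As a rough illustration of that general idea (not oneDAL's actual code), a random forest typically draws a fresh random subset of max_features candidate features at every node and searches for the best split only among them. The sketch assumes scikit-learn-style semantics for max_features (None means all features, a float in (0, 1] is a fraction, an int is a count; string options are omitted), and resolve_max_features and sample_split_features are hypothetical helper names. For reference, recent scikit-learn releases default RandomForestRegressor to max_features=1.0, so every feature is a candidate at every split; an implementation that samples or defaults this differently can plausibly end up with noticeably different accuracy.

```python
import random
from typing import List, Optional, Union


def resolve_max_features(max_features: Optional[Union[int, float]], n_features: int) -> int:
    """Number of candidate features examined at each split (sklearn-style int/float/None)."""
    if max_features is None:
        return n_features
    if isinstance(max_features, float):
        if not 0.0 < max_features <= 1.0:
            raise ValueError("a float max_features must lie in (0, 1]")
        return max(1, int(max_features * n_features))
    if not 1 <= max_features <= n_features:
        raise ValueError("an integer max_features must lie in [1, n_features]")
    return max_features


def sample_split_features(n_features: int, max_features, rng: random.Random) -> List[int]:
    """Draw one node's candidate feature indices, without replacement."""
    k = resolve_max_features(max_features, n_features)
    return sorted(rng.sample(range(n_features), k))


if __name__ == "__main__":
    rng = random.Random(1)
    # Each node gets its own, independently drawn candidate set
    for node_id in range(3):
        print(node_id, sample_split_features(10, 0.3, rng))
```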
Describe the bug
The random forest regressor with the sklearnex patch produces a much larger MSE than the original sklearn random forest regressor.
The example in the documentation compares the implementations with and without the patch by calling patch_sklearn()/unpatch_sklearn(), and both produce a similar MSE. However, using sklearn directly produces a much lower MSE.
To Reproduce
Removing the patch_sklearn()/unpatch_sklearn() calls from the random forest example in the documentation reproduces the error.
Environment: