microsoft / LightGBM

A fast, distributed, high performance gradient boosting (GBT, GBDT, GBRT, GBM or MART) framework based on decision tree algorithms, used for ranking, classification and many other machine learning tasks.
https://lightgbm.readthedocs.io/en/latest/
MIT License

I cannot reproduce results of quantile regression when using a custom metric or objective. #6062

Open HowardRiddiough opened 1 year ago

HowardRiddiough commented 1 year ago

Description

I am working on a project where we want to make conservative predictions, i.e., increase the likelihood that the model produces a negative error.

At the moment we are using LightGBM's quantile regression to predict the 90th percentile. I am happy with the results, but I would like to tweak the loss function slightly to introduce a quadratic term that punishes large negative errors disproportionately more than small negative errors. As I mentioned earlier, we are looking to make conservative predictions, but a prediction that is too conservative doesn't deliver any value.

With that end in mind I have been trying to reproduce the stock quantile regressor's results. I want to begin by reproducing the stock behaviour so that I have a solid foundation from which to start modifying the quantile loss function.

When running the example below you will see that the predictions made by the two models do not match. That may be because I haven't constructed the quantile loss function correctly. It may also be that my custom quantile regressor does not calculate tree output in the same way as the stock quantile regressor: I can see in the regression_objective.hpp file that the RegressionQuantileloss class does something with percentiles when calculating tree output that the standard regression loss does not.

Reproducible example

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from lightgbm import LGBMRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error

data_url = 'http://lib.stat.cmu.edu/datasets/boston'
raw_df = pd.read_csv(data_url, sep=r'\s+', skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]

X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

# Construct a model using LightGBM's quantile regressor
lgbm_stock = LGBMRegressor(objective='quantile', alpha=0.9, metric='quantile')
lgbm_stock.fit(X_train, y_train)
y_pred_stock = lgbm_stock.predict(X_test)

# Construct a model using a custom objective designed to produce the same predictions as the stock quantile regressor
alpha = 0.9

def quantile_loss_objective(y_true, y_pred):
    # The scikit-learn API passes custom objectives (y_true, y_pred)
    error = y_pred - y_true
    # Pinball loss gradient: (1 - alpha) for over-predictions, -alpha for under-predictions
    gradient = np.where(error >= 0, 1 - alpha, -alpha)
    # Pinball loss is piecewise linear, so use a constant hessian of 1
    hessian = np.ones_like(error)

    return gradient, hessian

lgbm_custom = LGBMRegressor(objective=quantile_loss_objective)
lgbm_custom.fit(X_train, y_train)
y_pred_custom = lgbm_custom.predict(X_test)

# Plot predictions on test
plt.hist(y_pred_stock, label='stock', alpha=0.5)
plt.hist(y_pred_custom, label='custom', alpha=0.5)
plt.legend()
plt.title("LightGBM's quantile regressor vs LightGBM regressor with quantile loss objective")
plt.show()

Environment info

python==3.9.5, numpy==1.23.5, matplotlib==3.7.1, lightgbm==3.3.5, pandas==1.5.3, scikit-learn==1.3.0

I am working on a MacBook with an M1 chip.

Summary

Thanks for taking a look at this issue; I would really appreciate any help. My questions are as follows:

[figure: regressor-outputs, histograms comparing the stock and custom model predictions on the test set]

shiyu1994 commented 1 year ago

Note that with quantile loss, LightGBM does not only fit on the gradients and hessians. After the tree is grown, there is another step that renews the leaf values directly with the quantile of the residuals.

https://github.com/microsoft/LightGBM/blob/858eeb54a215ff50fd3ca0033b77d51bf48f2c1f/src/objective/regression_objective.hpp#L540

Thus, even though a customized function can calculate the gradients and hessians accordingly, there is no such renewal step when a customized objective is used.
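In NumPy terms, a minimal sketch of that renewal rule (illustrative only; renew_leaf_outputs is a made-up helper, not the C++ implementation linked above, and sample weights are ignored):

import numpy as np

def renew_leaf_outputs(leaf_index, residuals, alpha=0.9):
    """Illustrative only: after a tree is grown on the pinball-loss
    gradients, each leaf's output is replaced by the empirical
    alpha-quantile of the residuals (label minus the prediction before
    this tree) of the samples that fall into that leaf."""
    new_outputs = {}
    for leaf in np.unique(leaf_index):
        new_outputs[leaf] = np.quantile(residuals[leaf_index == leaf], alpha)
    return new_outputs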

HowardRiddiough commented 1 year ago

@shiyu1994 Thanks for the information. Is there any way I can implement this extra step when constructing my model in Python?

HowardRiddiough commented 1 year ago

Would you be open to me making a PR for a new regression objective, something like a quadratic quantile loss function? Such an objective would be very valuable to my company.

shiyu1994 commented 1 year ago

Would you be open to me making a PR for a new regression objective, something like a quadratic quantile loss function? Such an objective would be very valuable to my company.

I'm happy with that. It would be good to hear opinions from @guolinke @jameslamb @jmoralez

mayer79 commented 1 year ago

What target functional $T(Y \mid X)$ would be associated with it, @HowardRiddiough? Put differently, what quantity is being modelled?

Examples for $T$: squared error, Poisson, gamma, and logloss are all associated with the (conditional) expectation; pinball loss is associated with the given (conditional) quantile. With Pseudo-Huber loss, things already get tricky.
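For reference, the standard identity behind that remark: the pinball loss at level $\alpha$ and the functional it elicits are

$$
L_\alpha(y, \hat{y}) =
\begin{cases}
\alpha \, (y - \hat{y}) & \text{if } y \ge \hat{y} \\
(1 - \alpha) \, (\hat{y} - y) & \text{if } y < \hat{y}
\end{cases},
\qquad
\arg\min_{\hat{y}} \; \mathbb{E}\left[ L_\alpha(Y, \hat{y}) \mid X \right] = Q_\alpha(Y \mid X),
$$

where $Q_\alpha$ denotes the conditional $\alpha$-quantile. Making the two branches quadratic changes the minimizer: an asymmetrically weighted squared loss is known to target an expectile rather than a quantile, so a "quadratic quantile loss" would model a different quantity.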

HowardRiddiough commented 1 year ago

Another option would be to let users combine a custom objective with the quantile regression renewal step. Would that be possible?

robert-robison commented 12 months ago

Note that with quantile loss, LightGBM does not only fit on the gradients and hessians. After the tree is grown, there is another step that renews the leaf values directly with the quantile of the residuals.

@shiyu1994 Can you explain in more detail what actually happens in this renewal step? Are the gradients all scaled by a constant that depends on the average absolute value of the residuals?

Edit: I actually found a good explanation here: https://jmarkhou.com/lgbqr/

Basically, after fitting the tree on the gradients, there is an extra step that computes the empirical quantile of the residuals of the data points within each leaf. That is, the tree structure is fit using the first-order approximation, but each leaf's output is then set by directly minimizing the quantile loss over the samples in that leaf.
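From the Python side, the effect of that step can be inspected (though not re-applied to the model) by grouping training residuals by leaf. A rough sketch, reusing lgbm_stock, X_train and y_train from the example above; note that these residuals are taken against the full model, whereas during training the renewal uses residuals computed before the current tree is added:

import numpy as np

# Group training samples by the leaf they fall into in the last tree and
# compute the empirical 0.9-quantile of the residuals per leaf. This only
# inspects the fitted model; it does not modify any leaf outputs.
leaf_ids = lgbm_stock.predict(X_train, pred_leaf=True)  # shape: (n_samples, n_trees)
residuals = y_train - lgbm_stock.predict(X_train)

last_tree_leaves = leaf_ids[:, -1]
for leaf in np.unique(last_tree_leaves):
    in_leaf = residuals[last_tree_leaves == leaf]
    print(f"leaf {leaf}: 0.9-quantile of residuals = {np.quantile(in_leaf, 0.9):.3f}")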

HowardRiddiough commented 11 months ago

@robert-robison have you found a way to combine a custom objective with the quantile regression renewal step?

robert-robison commented 11 months ago

@HowardRiddiough No. The closest we've been able to come is weighting the gradients by the absolute value of the error (sketched below). But this is a kind of "squared quantile loss"; it won't give you the same results as the quantile objective.

The renewal step cannot be duplicated (as far as I can tell) within a custom objective function. Would love to hear of better solutions if you (or anyone) have them!
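For completeness, a sketch of the workaround described above; the exact weighting used is an assumption, and squared_quantile_objective is a made-up name:

import numpy as np
from lightgbm import LGBMRegressor

alpha = 0.9

def squared_quantile_objective(y_true, y_pred):
    # Hypothetical form of the workaround: the pinball gradient scaled by
    # |error|, so large errors are punished quadratically. Not equivalent to
    # the stock quantile objective, which also renews leaf values.
    error = y_pred - y_true
    grad = np.where(error >= 0, 1 - alpha, -alpha) * np.abs(error)
    hess = np.ones_like(error)
    return grad, hess

# Usage (hypothetical): lgbm = LGBMRegressor(objective=squared_quantile_objective)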