jpmml / jpmml-xgboost

Java library and command-line application for converting XGBoost models to PMML

Support for `survival:aft` objective #63

Closed: WeijiaZhang24 closed this issue 1 year ago

WeijiaZhang24 commented 2 years ago

Is there any plan to support Accelerated Failure Time Model for survival analysis in XGBoost?

vruusmann commented 2 years ago

How is survival analysis different from regression analysis? Perhaps the solution is as simple as adding survival:aft as an alias to some existing objective function.

Anyway, if you need this functionality, then it's best if you help yourself. For starters, please provide a reproducible example (toy dataset plus sample Python/R code) that would train a "currently not supported" XGBoost model subtype.

WeijiaZhang24 commented 2 years ago

The regression target y in survival analysis is censored. This is usually encountered in medical follow-up studies where patients drop out: we only know that a patient was healthy up until the dropout time, and we do not know when/if the disease returns after the dropout time.

There are four cases:

- When the data is not censored, y is just a number (in the following example, its upper and lower bounds are both 2).
- When the data is right-censored, y has a lower bound but no upper bound (in the following example, we only know that y > 3).
- When the data is left-censored, y has an upper bound but no lower bound (we only know that y < 4; this is rarely encountered in practice).
- For interval censoring, y falls within an interval (4 < y < 5 in the example code).

Here's some example code:

import numpy as np
import xgboost as xgb
# 4-by-2 Data matrix
X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]])
dtrain = xgb.DMatrix(X)
# Associate ranged labels with the data matrix.
# This example shows each kind of censored labels.
#                         uncensored    right     left  interval
y_lower_bound = np.array([      2.0,     3.0,     0.0,     4.0])
y_upper_bound = np.array([      2.0, +np.inf,     4.0,     5.0])
dtrain.set_float_info('label_lower_bound', y_lower_bound)
dtrain.set_float_info('label_upper_bound', y_upper_bound)

params = {'objective': 'survival:aft',
          'eval_metric': 'aft-nloglik',
          'aft_loss_distribution': 'normal',
          'aft_loss_distribution_scale': 1.20,
          'tree_method': 'hist', 'learning_rate': 0.05, 'max_depth': 2}
bst = xgb.train(params, dtrain, num_boost_round=5,
                evals=[(dtrain, 'train')])

vruusmann commented 2 years ago

Here's the relevant part of XGBoost 1.6.1 codebase: https://github.com/dmlc/xgboost/blob/v1.6.1/src/objective/aft_obj.cu#L107-L116

TL;DR: First, collect the individual tree predictions into an array and apply the exp() transformation to them (in order to transform from the log scale back to the natural scale). Then, apply the common::Range function to this collection.
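
As a quick sanity check of that TL;DR (a minimal sketch, assuming the `bst` and `dtrain` objects from the code example above are still in scope), the untransformed margin should exponentiate to the reported prediction:

import numpy as np

# Raw per-row margins on the log scale (output_margin=True skips PredTransform)
margin = bst.predict(dtrain, output_margin=True)

# If PredTransform is just exp(), these two should coincide
print(np.allclose(bst.predict(dtrain), np.exp(margin)))  # expected: True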

@WeijiaZhang24 Basically, what do you expect as a model output in this case? An array of values (not some scalar value)?

vruusmann commented 2 years ago

@WeijiaZhang24 Basically, what do you expect as a model output in this case?

Trying it out myself:

yt = bst.predict(dtrain)
print(yt.shape)
print(yt)

This prints:

(4,)
[0.6762048 0.6762048 0.6762048 0.6762048]

Looks like the expected output is still a scalar?

WeijiaZhang24 commented 2 years ago

Yes, in the case of XGBoost, the output is still a scalar, namely the time-to-event prediction. However, in general survival analysis, the desired prediction outputs are often survival curves (the probability of survival, i.e. no event, at each time point).

The prediction of survival curves is supported in scikit-survival (as in the other issues we are discussing in the sklearn2pmml repository). For example, for Random Survival Forest they have a predict() function, which behaves like XGBoost's predict(), and also a predict_survival_function() function that outputs a step function for the survival curve. See https://scikit-survival.readthedocs.io/en/stable/user_guide/random-survival-forest.html for an example and visualization of the survival curves.
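
For illustration, here is a minimal sketch of that API on synthetic data (the dataset is invented; it assumes scikit-survival is installed):

import numpy as np
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

rng = np.random.default_rng(0)

# Synthetic covariates and right-censored targets (invented for illustration)
X = rng.normal(size=(100, 3))
time = np.exp(X[:, 0] + rng.normal(scale=0.5, size=100))
event = rng.random(100) < 0.7  # False = right-censored

# scikit-survival encodes (event indicator, observed time) as a structured array
y = Surv.from_arrays(event=event, time=time)

rsf = RandomSurvivalForest(n_estimators=50, random_state=0).fit(X, y)

# Scalar risk scores, analogous to XGBoost's predict()
print(rsf.predict(X[:2]))

# Step functions: survival probability at each time point
for fn in rsf.predict_survival_function(X[:2]):
    print(fn.x[:5], fn.y[:5])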

For this reason, maybe focusing on the scikit-survival package will benefit more people who use survival analysis.

vruusmann commented 2 years ago

@WeijiaZhang24 Can you perform the following experiment locally?

First, insert a case statement for the survival:aft objective at the end of this block of case statements: https://github.com/jpmml/jpmml-xgboost/blob/1.7.0/pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java#L534-L538

Something like this:

switch(name_obj){
    case "reg:linear":
    case "reg:pseudohubererror":
    case "reg:squarederror":
    case "reg:squaredlogerror":
    // THIS!
    case "survival:aft":
        return new LinearRegression(name_obj);
}

Then, rebuild the project using Apache Maven as detailed in the README file:

$ mvn clean install

This will give you a custom JPMML-XGBoost library version that should be able to export your survival analysis models as regular regression models.
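
With the rebuilt JAR, the conversion itself goes through the command-line converter as described in the README; something like the following (the exact JAR file name and version suffix depend on your checkout, and aft.model/aft.fmap/aft.pmml are placeholder file names):

$ java -jar pmml-xgboost-example/target/pmml-xgboost-example-executable-1.7-SNAPSHOT.jar \
    --model-input aft.model --fmap-input aft.fmap --pmml-output aft.pmml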

Now, the goal is to figure out if you can obtain the desired "survival score" by simply exponentiating the "regression score". I can't run this test with the above toy code example, because it trains an empty (no-op) XGBoost model due to the meaninglessness of the training data.
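
Since the toy example trains a constant model, here is a sketch of a less degenerate setup (synthetic data, invented for illustration) that yields varying predictions and saves the booster for conversion:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(42)

# Synthetic, non-degenerate survival data (invented for illustration)
X = rng.normal(size=(200, 4))
t = np.exp(X[:, 0] + 0.5 * rng.normal(size=200))  # event times, log-linear in X[:, 0]
censored = rng.random(200) < 0.3                  # ~30% right-censored

dtrain = xgb.DMatrix(X)
dtrain.set_float_info('label_lower_bound', t)
dtrain.set_float_info('label_upper_bound', np.where(censored, np.inf, t))

params = {'objective': 'survival:aft',
          'aft_loss_distribution': 'normal',
          'aft_loss_distribution_scale': 1.20,
          'tree_method': 'hist', 'learning_rate': 0.05, 'max_depth': 2}
bst = xgb.train(params, dtrain, num_boost_round=50)

print(bst.predict(dtrain)[:5])  # per-row predictions now vary
bst.save_model('aft.model')     # placeholder file name, matching the converter command above

If the hypothesis holds, exponentiating the PMML "regression score" for each row should reproduce the values of bst.predict(dtrain).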