jpmml / jpmml-lightgbm

Java library and command-line application for converting LightGBM models to PMML
GNU Affero General Public License v3.0

support cross_entropy as objective for lightgbm #50

Closed panlanfeng closed 3 years ago

panlanfeng commented 3 years ago

Hello,

Thanks for the great work on this project. I was wondering whether support for the cross_entropy objective is on your roadmap. My use case requires numeric probability labels in [0, 1]. I got the following error message when converting the model. Could you take a look? Thanks!

Jun 30, 2021 3:56:41 AM org.jpmml.lightgbm.Main run
INFO: Loading GBDT..
Jun 30, 2021 3:56:41 AM org.jpmml.lightgbm.Main run
SEVERE: Failed to load GBDT
java.lang.IllegalArgumentException: cross_entropy
        at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:529)
        at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
        at org.jpmml.lightgbm.Main.run(Main.java:137)
        at org.jpmml.lightgbm.Main.main(Main.java:127)

Exception in thread "main" java.lang.IllegalArgumentException: cross_entropy
        at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:529)
        at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
        at org.jpmml.lightgbm.Main.run(Main.java:137)
        at org.jpmml.lightgbm.Main.main(Main.java:127)

vruusmann commented 3 years ago

Does the cross_entropy objective function (aka xentropy) also use the sigmoid function for calculating probabilities?

If it does, then try simply inserting cross_entropy here: https://github.com/jpmml/jpmml-lightgbm/blob/1.3.9/src/main/java/org/jpmml/lightgbm/GBDT.java#L523

Something like this:

switch(objective){
  // BinaryLogloss
  case "binary":
  case "cross_entropy":
    return new BinomialLogisticRegression(average_output, config.getDouble("sigmoid"));
}

If you rebuild the project and redo the conversion, does the PMML model make correct predictions?

panlanfeng commented 3 years ago

It looks like cross_entropy does not use the sigmoid parameter. I made the change you suggested and got the following error when converting:


Jul 01, 2021 12:41:01 AM org.jpmml.lightgbm.Main run
INFO: Loading GBDT..
Jul 01, 2021 12:41:01 AM org.jpmml.lightgbm.Main run
SEVERE: Failed to load GBDT
java.lang.IllegalArgumentException: sigmoid
        at org.jpmml.lightgbm.Section.get(Section.java:106)
        at org.jpmml.lightgbm.Section.get(Section.java:100)
        at org.jpmml.lightgbm.Section.getDouble(Section.java:74)
        at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:525)
        at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
        at org.jpmml.lightgbm.Main.run(Main.java:137)
        at org.jpmml.lightgbm.Main.main(Main.java:127)

Exception in thread "main" java.lang.IllegalArgumentException: sigmoid
        at org.jpmml.lightgbm.Section.get(Section.java:106)
        at org.jpmml.lightgbm.Section.get(Section.java:100)
        at org.jpmml.lightgbm.Section.getDouble(Section.java:74)
        at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:525)
        at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
        at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
        at org.jpmml.lightgbm.Main.run(Main.java:137)
        at org.jpmml.lightgbm.Main.main(Main.java:127)

According to this line, cross_entropy computes the probability directly instead of calling the sigmoid function, and it does not take a sigmoid parameter the way binary classification does.
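The failure above can be reproduced in isolation. The sketch below is an illustration only: it assumes the objective entry of a LightGBM model file is an objective name followed by optional `key:value` tokens (for example `binary sigmoid:1`), and that `cross_entropy` is written without any such tokens. The class and method names are hypothetical and are not the JPMML-LightGBM API.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class ObjectiveConfigSketch {

    // Hypothetical parser: splits "name key:value ..." into a parameter map.
    public static Map<String, String> parseParams(String objective) {
        Map<String, String> params = new LinkedHashMap<>();
        String[] tokens = objective.trim().split("\\s+");
        // tokens[0] is the objective name; the rest are optional key:value pairs.
        for (int i = 1; i < tokens.length; i++) {
            int colon = tokens[i].indexOf(':');
            if (colon > 0) {
                params.put(tokens[i].substring(0, colon), tokens[i].substring(colon + 1));
            }
        }
        return params;
    }

    // Mirrors the failure mode in the stack trace: a missing key is an error.
    public static double getDouble(Map<String, String> params, String key) {
        String value = params.get(key);
        if (value == null) {
            throw new IllegalArgumentException(key);
        }
        return Double.parseDouble(value);
    }

    public static void main(String[] args) {
        // The binary objective carries a sigmoid parameter (assumed format).
        System.out.println(getDouble(parseParams("binary sigmoid:1"), "sigmoid"));
        try {
            // cross_entropy carries none, so the lookup fails.
            getDouble(parseParams("cross_entropy"), "sigmoid");
        } catch (IllegalArgumentException e) {
            System.out.println("missing parameter: " + e.getMessage());
        }
    }
}
```

Under these assumptions, any code path that unconditionally reads `sigmoid` will fail for `cross_entropy`, which matches the `IllegalArgumentException: sigmoid` shown in the log.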

panlanfeng commented 3 years ago

I was able to make it generate the correct score after making the following change to https://github.com/jpmml/jpmml-lightgbm/blob/1.3.9/src/main/java/org/jpmml/lightgbm/GBDT.java#L523

  case "cross_entropy":
    return new BinomialLogisticRegression(average_output, 1.0);

I can make a CR for this change if it looks OK to you.

vruusmann commented 3 years ago

> return new BinomialLogisticRegression(average_output, 1.0 );

Yes, that appears to be the solution. There is no need for an explicit sigmoid parameter, because the coefficient is hard-coded as 1.
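The equivalence can be checked numerically. The sketch below assumes that the binary objective maps a raw score to 1 / (1 + exp(-sigmoid * score)) and that cross_entropy applies the same formula without the coefficient. The class and method names are illustrative only and do not come from JPMML-LightGBM or LightGBM.

```java
public class SigmoidSketch {

    // Assumed output transformation of the "binary" objective.
    public static double binaryProbability(double score, double sigmoid) {
        return 1.0 / (1.0 + Math.exp(-sigmoid * score));
    }

    // Assumed output transformation of "cross_entropy": no coefficient.
    public static double crossEntropyProbability(double score) {
        return 1.0 / (1.0 + Math.exp(-score));
    }

    public static void main(String[] args) {
        double[] rawScores = {-3.0, -0.5, 0.0, 0.5, 3.0};
        for (double score : rawScores) {
            // With sigmoid fixed at 1.0 the two columns should agree.
            System.out.println(score
                + " -> " + binaryProbability(score, 1.0)
                + " vs " + crossEntropyProbability(score));
        }
    }
}
```

If the assumption holds, passing `1.0` as the second constructor argument makes the converted model reproduce the cross_entropy probabilities.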

> I can make a CR for this change if it looks OK to you.

Not needed. I'll add proper cross_entropy support, with test cases, for the next release myself.

In the meantime, you can keep using your patched codebase.

panlanfeng commented 3 years ago

Thanks! I was also wondering whether it is possible to add this cross_entropy support to the older 1.2.* versions as well. I ask because our team is still using version 1.2.*. It is OK if there is no such plan.

vruusmann commented 3 years ago

> I was also wondering if it is possible to also add this cross entropy support to history version 1.2.* as well?

I'll see if the 1.2.X development branch has the same API available that is being "touched" here. If it does, I'll implement the change in 1.2.X, and then merge forward to 1.3.X.

vruusmann commented 3 years ago

The fix is available both in JPMML-LightGBM 1.2.15 and 1.3.10.