jpmml / jpmml-evaluator

Java Evaluator API for PMML
GNU Affero General Public License v3.0

How to get the output leaf indices of every tree in a LightGBM/XGBoost model #233

Closed: liuyair closed this issue 2 years ago

liuyair commented 3 years ago

When making predictions, I want to get the output leaf index of every tree in my PMML LightGBM/XGBoost model. Any index format is OK, including one-hot, label-encoded or tree node index.

The PMML model is generated in Python from a scikit-learn model, using sklearn2pmml or jpmml-lightgbm.

Essentially, I want the equivalent of LGBMClassifier.predict(data, pred_leaf=True) in Python.

How can I do that in Java using JPMML-Evaluator?

vruusmann commented 3 years ago

I want to get the output leaf indices of every tree... any index format is OK, including onehot/labelencoded/tree node idx.

In the PMML representation, tree nodes are identified by the Node@id attribute: http://dmg.org/pmml/v4-4-1/TreeModel.html#xsdElement_Node

This is an optional attribute; if missing, the PMML engine shall assign "virtual" 1-based integer identifiers.

How can I do that in Java using JPMML-Evaluator?

The results from tree-based models (decision trees, decision tree ensembles) typically implement the org.jpmml.evaluator.tree.HasDecisionPath marker interface: https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/HasDecisionPath.java

This marker interface, possibly in combination with the model-level org.jpmml.evaluator.tree.HasNodeRegistry marker interface, should provide all the information needed to achieve custom application goals: https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/HasNodeRegistry.java
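
To make the "virtual identifier + registry" idea concrete, here is a plain-Java sketch that compiles without JPMML-Evaluator. It rests on two assumptions. First, nodes without a Node@id attribute are numbered 1, 2, 3, ... in depth-first pre-order. Second, a registry offers id -> node and node -> id lookups, which is roughly the role the source attributes to HasNodeRegistry. The TreeNode class and the NodeRegistrySketch methods are hypothetical and are not JPMML's API.

```java
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class NodeRegistrySketch {

    // Hypothetical stand-in for a PMML Node without an explicit id attribute.
    public static final class TreeNode {
        public final double score;
        public final List<TreeNode> children;

        public TreeNode(double score, TreeNode... children) {
            this.score = score;
            this.children = Arrays.asList(children);
        }
    }

    // id -> node, kept in registration (pre-order) order
    public final Map<String, TreeNode> byId = new LinkedHashMap<>();
    // node -> id, keyed by object identity
    public final Map<TreeNode, String> idOf = new IdentityHashMap<>();

    public NodeRegistrySketch(TreeNode root) {
        register(root);
    }

    // Depth-first pre-order: a node is numbered before any of its children.
    private void register(TreeNode node) {
        String id = String.valueOf(byId.size() + 1); // 1-based "virtual" id
        byId.put(id, node);
        idOf.put(node, id);
        for (TreeNode child : node.children) {
            register(child);
        }
    }

    public static void main(String[] args) {
        TreeNode leafA = new TreeNode(-0.5);
        TreeNode leafB = new TreeNode(0.5);
        TreeNode leafC = new TreeNode(1.0);
        TreeNode root = new TreeNode(0.0, new TreeNode(0.1, leafA, leafB), leafC);

        NodeRegistrySketch registry = new NodeRegistrySketch(root);
        // Suppose an evaluator reported leafB as the winning node:
        System.out.println("winning node id = " + registry.idOf.get(leafB)); // 4
        System.out.println("node count = " + registry.byId.size());          // 5
    }
}
```

The point of the reverse map is that a node object alone does not carry its virtual id. You need some registry to translate the winning node back into an identifier.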

The situation is a bit more complicated with tree ensemble models (XGBoost, LightGBM, GBT, Random Forest), because the prediction result is "layered", which means that the o.j.e.tree.HasDecisionPath object is wrapped inside an org.jpmml.evaluator.mining.HasSegmentation object: https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/mining/HasSegmentation.java

The internal structure of the o.j.e.mining.HasSegmentation object depends on the mining function. For regression-type decision tree ensembles it's simpler (booster only); for classification-type decision tree ensembles it's more complex (booster followed by boosted score normalizer).
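
The sketch below illustrates "booster followed by boosted score normalizer" for a binary classifier. It assumes the booster output is the plain sum of the per-tree scores, and that the normalizer is the logistic (sigmoid) function, as for LightGBM's default binary objective. Real models may also include an initial score or use a different link function. The class and method names are hypothetical.

```java
public class BoostedScoreSketch {

    // Booster: add up the scores of the winning leaves, one per tree (segment).
    public static double boost(double[] treeScores) {
        double sum = 0.0;
        for (double score : treeScores) {
            sum += score;
        }
        return sum;
    }

    // Normalizer (assumed logistic): map the raw boosted score to P(class = 1).
    public static double probabilityOfOne(double rawScore) {
        return 1.0 / (1.0 + Math.exp(-rawScore));
    }

    public static void main(String[] args) {
        // Hypothetical per-tree leaf scores for one input row.
        double[] treeScores = {0.25, -0.10, 0.05};
        double raw = boost(treeScores);
        double p1 = probabilityOfOne(raw);
        System.out.printf("raw=%.2f P(1)=%.4f P(0)=%.4f%n", raw, p1, 1.0 - p1);
    }
}
```

For leaf extraction, the important consequence is that per-tree information lives only in the booster layer. The normalizer sees a single summed number, so the per-tree results must be read from the booster's segment results.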

vruusmann commented 2 years ago

TLDR: Use the following approach:

  1. Start with simple decision trees. For example, any of my DecisionTreeAudit.pmml models.
  2. Make a prediction using a simple decision tree, and cast its target value to org.jpmml.evaluator.tree.HasDecisionPath.
  3. Extract HasDecisionPath#getNode() and process it.
  4. Move to a more complex example. For example, a regression-type XGBoost or LightGBM model such as my XGBoostAuto.pmml or LightGBMAuto.pmml models.
  5. Make a prediction using it, and cast the target value to org.jpmml.evaluator.mining.HasSegmentation.
  6. Extract individual segment targets, and process them.
  7. Move to an even more complex example. For example, classification-type XGBoost or LightGBM models.
  8. Make a prediction using it, and cast the target value to o.j.e.mining.HasSegmentation. Extract the partial result corresponding to the booster component, and process it according to steps five and six above.
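
Steps 5 to 8 amount to a recursive walk over the layered result. The sketch below uses two hypothetical stand-in classes so that it compiles without JPMML-Evaluator. LeafResult plays the role of o.j.e.tree.HasDecisionPath, and SegmentedResult plays the role of o.j.e.mining.HasSegmentation. In real code, the instanceof checks would target the JPMML interfaces instead. The leaf id would then be derived from the winning node (HasDecisionPath#getNode()).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class LeafIdCollectorSketch {

    // Hypothetical stand-in for a tree result that knows its winning node.
    public static final class LeafResult {
        public final double score;
        public final String entityId;

        public LeafResult(double score, String entityId) {
            this.score = score;
            this.entityId = entityId;
        }
    }

    // Hypothetical stand-in for an ensemble result made of segment target values.
    public static final class SegmentedResult {
        public final List<Object> segmentTargetValues;

        public SegmentedResult(Object... segmentTargetValues) {
            this.segmentTargetValues = Arrays.asList(segmentTargetValues);
        }
    }

    // Returns one leaf id per tree, in segment order (roughly one row of pred_leaf).
    public static List<String> collectLeafIds(Object targetValue) {
        List<String> leafIds = new ArrayList<>();
        collect(targetValue, leafIds);
        return leafIds;
    }

    private static void collect(Object value, List<String> leafIds) {
        if (value instanceof LeafResult) {
            leafIds.add(((LeafResult) value).entityId);
        } else if (value instanceof SegmentedResult) {
            for (Object segmentValue : ((SegmentedResult) value).segmentTargetValues) {
                collect(segmentValue, leafIds);
            }
        }
        // Anything else (e.g. a plain Double from a normalizer segment) carries no leaf id.
    }

    public static void main(String[] args) {
        // Classification-style layering: a booster segment (three trees), then a normalizer segment.
        SegmentedResult booster = new SegmentedResult(
            new LeafResult(0.25, "3"), new LeafResult(-0.10, "1"), new LeafResult(0.05, "6"));
        SegmentedResult top = new SegmentedResult(booster, 0.5498);
        System.out.println(collectLeafIds(top)); // [3, 1, 6]
    }
}
```

Recursing, rather than hard-coding two levels, lets the same walk handle both the regression layout (booster only) and the classification layout (booster plus normalizer).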

vruusmann commented 2 years ago

See also the following two sample projects about dealing with decision tree ensemble (RF) models:

  1. https://github.com/vruusmann/rf_feature_impact
  2. https://github.com/vruusmann/rf_recordcount

Closing this issue, as the provided guidance should be sufficient to continue on your own. Feel free to ask clarifying/follow-up questions if necessary.

liuyair commented 2 years ago

Thanks for your detailed reply and guidance. Over the past few days I've tried to implement "node-id extraction" following your guidance, and I've run into some new problems.

PS1: For now, I'm following your project https://github.com/vruusmann/rf_feature_impact to get the leaf node ids. The code blocks below are taken from that project without modification, except for my custom data, model and System.out.println() calls. Thanks for the helpful reference.

PS2: All the code below works with jpmml-evaluator 1.4.15. With the latest 1.5.16, there seem to be many breaking changes affecting your rf_feature_impact project. I also tried to reproduce the whole process under 1.5.16, but targetValue became a double (regression model) or a ProbabilityDistribution (classification model) object, and couldn't be cast to HasSegmentation or HasDecisionPath by:

HasSegmentation hasSegmentation = (HasSegmentation)targetValue;
HasDecisionPath hasDecisionPath = (HasDecisionPath)targetValue;

We may discuss this version problem later.


I've reached your step 8. Using one sample row, I made a prediction with my LGB classification model, obtained its org.jpmml.evaluator.mining.HasSegmentation output, and got the target value inside each SegmentResult:

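import java.util.Collection;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.mining.HasSegmentation;
import org.jpmml.evaluator.mining.SegmentResult;

Map<FieldName, ?> results = evaluator.evaluate(arguments);
Object targetValue = results.get(targetField.getName());
HasSegmentation hasSegmentation = (HasSegmentation)targetValue;
Collection<? extends SegmentResult> segmentResults = hasSegmentation.getSegmentResults();

for(SegmentResult segmentResult : segmentResults){
    Object segmentTargetValue = segmentResult.getTargetValue(); // per-tree (segment) result
}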

Then I inspected segmentTargetValue in your computeFeatureContributions function:

static
private List<Contribution> computeFeatureContributions(String segmentId, Number weight, Object targetValue, String targetClass){
    HasDecisionPath hasDecisionPath = (HasDecisionPath)targetValue;  // Here targetValue == segmentTargetValue
    System.out.println("segmentId: " + segmentId);
    System.out.println("targetValue: " + targetValue);
    System.out.println("targetValue type: " + targetValue.getClass().toString());
    System.out.println("hasDecisionPath node: " + hasDecisionPath.getNode());  // Gets the winning node.
    // ... (remainder of the method as in rf_feature_impact)
}

The printed result (an example from my LGB model's 500th tree):

>>> segmentId: 500
>>> targetValue: {result=-3.811698813155931E-4, entityId=1}
>>> targetValue type: class org.jpmml.evaluator.tree.TreeModelEvaluator$1
>>> hasDecisionPath node: org.dmg.pmml.tree.ComplexNode@60e9df3c

I want to get the output leaf indices of every tree from my PMML LightGBM/XGBoost model

targetValue has an entityId element, which ranges from 1 to 7, matching my LGB model's num_leaves=7. This entityId looks like what I want. However, when I ran prediction on the same sample data in Python with lightgbm.Booster.predict(data, pred_leaf=True), the leaf node ids were totally different from what I got in Java (the predicted probabilities are the same). I checked the ids from Python, and they follow this kind of tree index order:

[image: tree]
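
The sketch below shows why the two numberings need not agree. It assumes LightGBM's text model layout, where each internal node stores left_child/right_child values and a negative child value c denotes leaf ~c (so -1 is leaf 0). For every leaf, it computes the native leaf index (what pred_leaf=True reports) and the leaf's 1-based position in a depth-first pre-order numbering of all nodes, internal ones included. That position is what a "virtual" id would look like under the pre-order assumption, with the left child visited first. The arrays describe a made-up 4-leaf tree.

```java
import java.util.Map;
import java.util.TreeMap;

public class LeafNumberingSketch {

    // Returns: native leaf index -> 1-based depth-first pre-order position among all nodes.
    public static Map<Integer, Integer> preOrderIdsOfLeaves(int[] leftChild, int[] rightChild) {
        Map<Integer, Integer> result = new TreeMap<>();
        visit(0, leftChild, rightChild, new int[]{0}, result); // internal node 0 is the root
        return result;
    }

    // ref >= 0: internal node index; ref < 0: leaf number ~ref (bitwise NOT).
    private static void visit(int ref, int[] leftChild, int[] rightChild, int[] counter, Map<Integer, Integer> result) {
        counter[0]++; // this node (internal or leaf) takes the next 1-based id
        if (ref < 0) {
            result.put(~ref, counter[0]);
            return;
        }
        visit(leftChild[ref], leftChild, rightChild, counter, result);
        visit(rightChild[ref], leftChild, rightChild, counter, result);
    }

    public static void main(String[] args) {
        // Made-up 4-leaf tree in LightGBM's layout (3 internal nodes).
        int[] leftChild = {1, -1, -2};
        int[] rightChild = {2, -3, -4};
        System.out.println(preOrderIdsOfLeaves(leftChild, rightChild)); // {0=3, 1=6, 2=4, 3=7}
    }
}
```

Leaf 1 natively becomes position 6, and internal nodes consume ids too. So even for an uncompacted tree, a pre-order entity id is not the pred_leaf value, and a compacted PMML tree changes the shape further.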

In the PMML representation, tree nodes are identified by the Node@id attribute: http://dmg.org/pmml/v4-4-1/TreeModel.html#xsdElement_Node

This is an optional attribute; if missing, the PMML engine shall assign "virtual" 1-based integer identifiers.

My pmml model file doesn't contain the Node@id attribute. I wonder if the entityId I've got in Java is exactly the ' "virtual" 1-based integer identifiers' you mentioned? Can I trust this entityId and use it to correctly represent the prediction output leaves?

Another question: if entityId is what I need, how can I get it from targetValue?

targetValue is an instance of org.jpmml.evaluator.tree.TreeModelEvaluator$1 (an anonymous class), and it doesn't have methods like get() or getId(). The getNode() method of hasDecisionPath only returns an org.dmg.pmml.tree.ComplexNode with a "strange" id, 60e9df3c.

(Apologies if this is a silly question. I'm a Java beginner, and this project is my first contact with Java.)
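
The "strange" 60e9df3c in ComplexNode@60e9df3c is not a node identifier. The text has the shape of Java's default Object.toString() output: class name, "@", then the hash code in hexadecimal, for a class that does not override toString(). A minimal demonstration with a hypothetical class:

```java
public class DefaultToStringDemo {

    // A class that does not override toString(), like the node object printed above.
    public static final class PlainNode {
    }

    public static void main(String[] args) {
        PlainNode node = new PlainNode();
        String printed = String.valueOf(node);
        // Object.toString() is defined as getClass().getName() + "@" + Integer.toHexString(hashCode())
        String expected = node.getClass().getName() + "@" + Integer.toHexString(node.hashCode());
        System.out.println(printed);
        System.out.println(printed.equals(expected)); // true
    }
}
```

An identifier therefore has to come from the result object (e.g. its entityId) or from a node registry, not from printing the node.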


For reference, here is part of my PMML model file, showing some basic information and the structure of the 500th tree segment:

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
    <Header>
        <Application name="JPMML-LightGBM" version="1.3-SNAPSHOT"/>
        <Timestamp>2021-11-29T11:29:25Z</Timestamp>
    </Header>
    <DataDictionary>
        <DataField name="_target" optype="categorical" dataType="integer">
            <Value value="0"/>
            <Value value="1"/>
        </DataField>
        <DataField name="feature_001" optype="continuous" dataType="double">
            <Interval closure="closedClosed" leftMargin="0.0" rightMargin="11.232960820144106"/>
            <Value value="NaN" property="missing"/>
        </DataField>
        ......

                        <Segment id="500">
                            <True/>
                            <TreeModel functionName="regression" noTrueChildStrategy="returnLastPrediction">
                                <MiningSchema>
                                    <MiningField name="feature_002"/>
                                    <MiningField name="feature_003"/>
                                    <MiningField name="feature_004"/>
                                    <MiningField name="feature_005"/>
                                    <MiningField name="feature_006"/>
                                    <MiningField name="feature_007"/>
                                </MiningSchema>
                                <Node score="-3.811698813155931E-4">
                                    <True/>
                                    <Node score="-8.988185833947195E-4">
                                        <SimplePredicate field="feature_007" operator="greaterThan" value="0.5638774799243619"/>
                                        <Node score="-0.006818797571779276">
                                            <SimplePredicate field="feature_006" operator="greaterThan" value="48.50000000000001"/>
                                        </Node>
                                        <Node score="0.001541678356026245">
                                            <SimplePredicate field="feature_002" operator="greaterThan" value="0.8183625000000001"/>
                                        </Node>
                                    </Node>
                                    <Node score="-0.006749477204381504">
                                        <SimplePredicate field="feature_005" operator="greaterThan" value="0.025786767554062152"/>
                                        <Node score="0.009428389314363324">
                                            <SimplePredicate field="feature_003" operator="greaterThan" value="894.5000000000001"/>
                                        </Node>
                                    </Node>
                                    <Node score="0.010359079888978103">
                                        <SimplePredicate field="feature_004" operator="greaterThan" value="8.500000000000002"/>
                                    </Node>
                                </Node>
                            </TreeModel>
                        </Segment>
        ......
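
As a cross-check, this plain-Java sketch re-creates the segment 500 tree above and evaluates it with noTrueChildStrategy="returnLastPrediction" semantics. It numbers the seven nodes 1..7 in depth-first pre-order, which is an assumption about how virtual identifiers are assigned. When none of the root's three child predicates is true, the root itself wins. That gives score -3.811698813155931E-4 and id 1, consistent with the {result=..., entityId=1} output printed earlier. Missing values are simply treated as "predicate false", a simplification of PMML's missing-value handling. The feature values in main are made up.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

public class Segment500Sketch {

    // Hypothetical stand-in for a PMML Node: predicate, score, ordered children, virtual id.
    public static final class Node {
        public final Predicate<Map<String, Double>> predicate;
        public final double score;
        public final List<Node> children;
        public int virtualId;

        Node(Predicate<Map<String, Double>> predicate, double score, Node... children) {
            this.predicate = predicate;
            this.score = score;
            this.children = Arrays.asList(children);
        }
    }

    // SimplePredicate operator="greaterThan"; a missing value makes it false (simplification).
    static Predicate<Map<String, Double>> greaterThan(String field, double value) {
        return row -> row.containsKey(field) && row.get(field) > value;
    }

    public static Node segment500() {
        return new Node(row -> true, -3.811698813155931E-4, // root: <True/>
            new Node(greaterThan("feature_007", 0.5638774799243619), -8.988185833947195E-4,
                new Node(greaterThan("feature_006", 48.50000000000001), -0.006818797571779276),
                new Node(greaterThan("feature_002", 0.8183625000000001), 0.001541678356026245)),
            new Node(greaterThan("feature_005", 0.025786767554062152), -0.006749477204381504,
                new Node(greaterThan("feature_003", 894.5000000000001), 0.009428389314363324)),
            new Node(greaterThan("feature_004", 8.500000000000002), 0.010359079888978103));
    }

    // Depth-first pre-order numbering; returns the next unused id.
    public static int assignIds(Node node, int nextId) {
        node.virtualId = nextId++;
        for (Node child : node.children) {
            nextId = assignIds(child, nextId);
        }
        return nextId;
    }

    // Follow the first child whose predicate is true; if none is true,
    // returnLastPrediction makes the current node the winner.
    public static Node evaluate(Node node, Map<String, Double> row) {
        for (Node child : node.children) {
            if (child.predicate.test(row)) {
                return evaluate(child, row);
            }
        }
        return node;
    }

    public static void main(String[] args) {
        Node root = segment500();
        assignIds(root, 1);

        Map<String, Double> row = new HashMap<>();
        row.put("feature_004", 3.0);  // not > 8.5
        row.put("feature_005", 0.01); // not > 0.0258
        row.put("feature_007", 0.2);  // not > 0.5639
        Node winner = evaluate(root, row);
        System.out.println("result=" + winner.score + ", entityId=" + winner.virtualId);
    }
}
```

Under these assumptions, entityId=1 means "the root's own score was returned". It does not mean "leaf number 1", which is another reason it differs from pred_leaf.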

Version information:

[Java side]
  1. Java: 1.8
  2. jpmml-evaluator: 1.4.15 / 1.5.16
  3. PMML converter: jpmml-lightgbm 1.3
  4. PMML version: 4.3 (the converted PMML file was originally 4.4; I manually changed its header to 4.3, because 4.4 is not supported by jpmml-evaluator 1.4.15)

[Python side]
  5. Python: 3.7
  6. LightGBM: 3.2.1

Please feel free to tell me if you need more info about my program. Thanks.

vruusmann commented 2 years ago

All the code below works with jpmml 1.4.15. When I use the latest 1.5.16, there seem to be many breaking changes.

The most important change between the 1.4.X and 1.5.X development branches is that 1.5.X contains several decision tree evaluator implementations, and picks the most "lightweight" one that can do the job.

The 1.4.X-compatible decision tree evaluator is org.jpmml.evaluator.tree.ComplexTreeModelEvaluator: https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/ComplexTreeModelEvaluator.java

It returns o.j.e.tree.HasDecisionPath-compatible result values in all cases: https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/ComplexTreeModelEvaluator.java#L283-L423

The newer & lightweight tree evaluator is org.jpmml.evaluator.tree.SimpleTreeModelEvaluator: https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/SimpleTreeModelEvaluator.java

As you already observed, it returns java.lang.Number for regression cases, and java.lang.String for voting-style classification cases: https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/SimpleTreeModelEvaluator.java#L93-L100

Most decision tree evaluation tasks are fully served by the o.j.e.tree.SimpleTreeModelEvaluator. It creates less garbage, and is significantly more performant.

However, you want to access extra information that is not available when using o.j.e.tree.SimpleTreeModelEvaluator. The solution is therefore to manually force the activation of o.j.e.tree.ComplexTreeModelEvaluator.
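
The sketch below is a deliberately simplified model of that trade-off, not JPMML-Evaluator's actual factory code. When no extra result features are requested, a "simple" evaluator returns a bare Double, so a cast to a richer result type fails (as with the casts under 1.5.16 mentioned earlier). Requesting an entity-id feature switches to a "complex" evaluator, whose result carries the winning node's id. All types here are hypothetical stand-ins, and Feature only mimics org.dmg.pmml.ResultFeature.

```java
import java.util.EnumSet;
import java.util.Set;

public class EvaluatorChoiceSketch {

    // Hypothetical stand-in for org.dmg.pmml.ResultFeature.
    public enum Feature { ENTITY_ID }

    // Hypothetical rich result, in the spirit of a HasEntityId-style target value.
    public static final class TreeResult {
        public final double score;
        public final String entityId;

        TreeResult(double score, String entityId) {
            this.score = score;
            this.entityId = entityId;
        }
    }

    // Pretend the tree's winning node is node "1" with this score.
    static final double SCORE = -3.811698813155931E-4;

    public static Object evaluate(Set<Feature> extraFeatures) {
        if (extraFeatures.contains(Feature.ENTITY_ID)) {
            return new TreeResult(SCORE, "1"); // "complex": more allocation, more information
        }
        return SCORE; // "simple": just the number (autoboxed to Double)
    }

    public static void main(String[] args) {
        Object plain = evaluate(EnumSet.noneOf(Feature.class));
        Object rich = evaluate(EnumSet.of(Feature.ENTITY_ID));
        System.out.println(plain.getClass().getSimpleName()); // Double
        System.out.println(rich instanceof TreeResult);      // true
        System.out.println(((TreeResult) rich).entityId);    // 1
    }
}
```

Guarding casts with instanceof, as above, turns a silent configuration mismatch into an explicit branch instead of a ClassCastException.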

This can be achieved using the org.jpmml.evaluator.ModelEvaluatorBuilder#setExtraResultFeatures(Set<org.dmg.pmml.ResultFeature>) method. Since you're interested in node identifiers, you'd need to indicate org.dmg.pmml.ResultFeature#ENTITY_ID there. Something like this:

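import java.io.File;
import java.util.EnumSet;
import org.dmg.pmml.ResultFeature;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;

LoadingModelEvaluatorBuilder evaluatorBuilder = new LoadingModelEvaluatorBuilder()
  .load(new File("model.pmml")); // path to your PMML file
// THIS!
evaluatorBuilder.setExtraResultFeatures(EnumSet.of(ResultFeature.ENTITY_ID));
Evaluator evaluator = evaluatorBuilder.build();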

The resulting Evaluator will now do its best to return target values that implement the org.jpmml.evaluator.HasEntityId marker interface (org.jpmml.evaluator.tree.HasDecisionPath is one of its sub-marker interfaces).

vruusmann commented 2 years ago

My pmml model file doesn't contain the Node@id attribute.

The JPMML-LightGBM library initializes the Node@id attribute with native LightGBM identifier values: https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/Tree.java#L121 https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/Tree.java#L272

Node identifiers may get "erased" during decision tree compaction, as implemented by the org.jpmml.lightgbm.visitors.TreeModelCompactor visitor class.

They are required to be present initially: https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/visitors/TreeModelCompactor.java#L33 https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/visitors/TreeModelCompactor.java#L37-L39

But they get "erased": https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/visitors/TreeModelCompactor.java#L77

Decision tree compaction is active by default. If you are interested in preserving LightGBM decision trees in their native layout, then you should disable it by setting the org.jpmml.lightgbm.HasLightGBMOptions#OPTION_COMPACT to false: https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/HasLightGBMOptions.java#L29

For example, if you're converting LightGBM models using the SkLearn2PMML package, then you can toggle this option using the sklearn2pmml.pipeline.PMMLPipeline.configure(**pmml_options) method:

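from lightgbm import LGBMClassifier
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

pipeline = PMMLPipeline([
  ("classifier", LGBMClassifier())
])
pipeline.fit(X, y)  # X, y: your training data
# THIS!
pipeline.configure(compact = False)

sklearn2pmml(pipeline, "pipeline.pmml")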

Exactly the same applies to XGBoost models: you need to turn off decision tree compaction, which is active by default.