jpmml / jpmml-converter

Java library for authoring PMML
GNU Affero General Public License v3.0

Failing to prune XGBoost tree models #21

Open xiaoqiao21 opened 3 years ago

xiaoqiao21 commented 3 years ago

Hi, I want to use the sklearn2pmml() function to convert a pipeline into a PMML file.

I originally created the issue linked below, but I was not able to reopen it, so I created this new issue and copied the content here: https://github.com/jpmml/jpmml-sklearn/issues/160

Here is my code to create a pipeline. However, I got the following error:

RuntimeError: The JPMML-SkLearn conversion application has failed. The Java executable should have printed more information about the failure into its standard output and/or standard error streams

How can I solve it? My version is 0.73.1.

The console output is:

Standard output is empty
Standard error:
Jul 01, 2021 8:33:28 PM org.jpmml.sklearn.Main run
INFO: Parsing PKL..
Jul 01, 2021 8:33:28 PM org.jpmml.sklearn.Main run
INFO: Parsed PKL in 219 ms.
Jul 01, 2021 8:33:28 PM org.jpmml.sklearn.Main run
INFO: Converting PKL to PMML..
Jul 01, 2021 8:33:30 PM org.jpmml.sklearn.Main run
SEVERE: Failed to convert PKL to PMML
java.lang.IllegalArgumentException
    at org.jpmml.converter.visitors.AbstractTreeModelTransformer.initScore(AbstractTreeModelTransformer.java:173)
    at org.jpmml.converter.visitors.TreeModelPruner.exitNode(TreeModelPruner.java:81)
    at org.jpmml.converter.visitors.AbstractTreeModelTransformer.popParent(AbstractTreeModelTransformer.java:61)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:120)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:90)
    at org.dmg.pmml.tree.TreeModel.accept(TreeModel.java:401)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:90)
    at org.dmg.pmml.mining.Segment.accept(Segment.java:235)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.mining.Segmentation.accept(Segmentation.java:185)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:69)
    at org.dmg.pmml.mining.MiningModel.accept(MiningModel.java:349)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:90)
    at org.dmg.pmml.mining.Segment.accept(Segment.java:235)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.mining.Segmentation.accept(Segmentation.java:185)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:69)
    at org.dmg.pmml.mining.MiningModel.accept(MiningModel.java:349)
    at org.jpmml.model.visitors.AbstractVisitor.applyTo(AbstractVisitor.java:320)
    at org.jpmml.xgboost.Learner.encodeMiningModel(Learner.java:354)
    at xgboost.sklearn.BoosterUtil.encodeBooster(BoosterUtil.java:63)
    at xgboost.sklearn.XGBClassifier.encodeModel(XGBClassifier.java:45)
    at xgboost.sklearn.XGBClassifier.encodeModel(XGBClassifier.java:27)
    at sklearn.Estimator.encode(Estimator.java:83)
    at sklearn2pmml.pipeline.PMMLPipeline.encodePMML(PMMLPipeline.java:235)
    at org.jpmml.sklearn.Main.run(Main.java:226)
    at org.jpmml.sklearn.Main.main(Main.java:143)

Exception in thread "main" java.lang.IllegalArgumentException
    at org.jpmml.converter.visitors.AbstractTreeModelTransformer.initScore(AbstractTreeModelTransformer.java:173)
    at org.jpmml.converter.visitors.TreeModelPruner.exitNode(TreeModelPruner.java:81)
    at org.jpmml.converter.visitors.AbstractTreeModelTransformer.popParent(AbstractTreeModelTransformer.java:61)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:120)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.tree.SimpleNode.accept(SimpleNode.java:113)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:90)
    at org.dmg.pmml.tree.TreeModel.accept(TreeModel.java:401)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:90)
    at org.dmg.pmml.mining.Segment.accept(Segment.java:235)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.mining.Segmentation.accept(Segmentation.java:185)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:69)
    at org.dmg.pmml.mining.MiningModel.accept(MiningModel.java:349)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:90)
    at org.dmg.pmml.mining.Segment.accept(Segment.java:235)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
    at org.dmg.pmml.mining.Segmentation.accept(Segmentation.java:185)
    at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:69)
    at org.dmg.pmml.mining.MiningModel.accept(MiningModel.java:349)
    at org.jpmml.model.visitors.AbstractVisitor.applyTo(AbstractVisitor.java:320)
    at org.jpmml.xgboost.Learner.encodeMiningModel(Learner.java:354)
    at xgboost.sklearn.BoosterUtil.encodeBooster(BoosterUtil.java:63)
    at xgboost.sklearn.XGBClassifier.encodeModel(XGBClassifier.java:45)
    at xgboost.sklearn.XGBClassifier.encodeModel(XGBClassifier.java:27)
    at sklearn.Estimator.encode(Estimator.java:83)
    at sklearn2pmml.pipeline.PMMLPipeline.encodePMML(PMMLPipeline.java:235)
    at org.jpmml.sklearn.Main.run(Main.java:226)
    at org.jpmml.sklearn.Main.main(Main.java:143)

vruusmann commented 3 years ago

It looks like you've trained an XGBoost model that contains no-op nodes.

JPMML-XGBoost automatically tries to eliminate those nodes (because they are provably unreachable under any and all scenarios) by applying a special tree model pruning algorithm, implemented as org.jpmml.converter.visitors.TreeModelPruner.
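To make the idea of a "provably unreachable" node concrete, here is a minimal Python sketch. It is not JPMML's TreeModelPruner: the tree shape is a hypothetical, simplified dict loosely modeled on XGBoost's JSON tree dump (a split sends x to its first child when x[feature] < split_condition, otherwise to its second child), and the helper names prune and predict are made up for illustration. The sketch tracks the interval of values each feature can still take on the path from the root, and drops any branch whose condition contradicts an ancestor's condition. It assumes that no input values are missing; XGBoost routes missing values along a default direction, which can make such a branch reachable after all.

```python
import math


def prune(node, bounds=None):
    """Return a copy of `node` with provably unreachable branches removed.

    Hypothetical tree shape: a leaf is {"leaf": value}; a split is
    {"split": feature, "split_condition": t, "children": [yes, no]},
    where `yes` is taken when x[feature] < t. Assumes no missing values.
    """
    if bounds is None:
        bounds = {}
    if "leaf" in node:
        return dict(node)
    feature, t = node["split"], node["split_condition"]
    # Along this path, x[feature] is known to lie in [lo, hi)
    lo, hi = bounds.get(feature, (-math.inf, math.inf))
    yes_child, no_child = node["children"]
    if t <= lo:
        # x >= lo >= t, so "x < t" never holds: the "yes" branch is dead
        return prune(no_child, bounds)
    if t >= hi:
        # x < hi <= t, so "x >= t" never holds: the "no" branch is dead
        return prune(yes_child, bounds)
    pruned = {k: v for k, v in node.items() if k != "children"}
    pruned["children"] = [
        prune(yes_child, {**bounds, feature: (lo, t)}),
        prune(no_child, {**bounds, feature: (t, hi)}),
    ]
    return pruned


def predict(node, x):
    """Walk the tree for a dict of feature values `x` and return the leaf value."""
    while "leaf" not in node:
        yes_child, no_child = node["children"]
        node = yes_child if x[node["split"]] < node["split_condition"] else no_child
    return node["leaf"]


tree = {
    "split": "f0", "split_condition": 5.0,
    "children": [
        {"split": "f0", "split_condition": 10.0,
         "children": [{"leaf": 0.1}, {"leaf": 0.7}]},
        {"leaf": -0.3},
    ],
}
pruned = prune(tree)
print(pruned)
```

In the example, the inner split on f0 < 10.0 sits below f0 < 5.0, so its f0 >= 10.0 branch can never be taken and the split collapses into its 0.1 leaf, while predictions for every non-missing input stay the same.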

I was sure that the tree pruning code would always succeed. However, you've managed to train an XGBoost model with such an unusual "internal structure" that the tree pruning code still fails.

Can you share your model file so that I can take a look at this unusual "internal structure" myself? Or, if it's trained on proprietary data, can you reproduce the pruning error using some publicly available toy dataset?

vruusmann commented 3 years ago

As a workaround, you should disable tree pruning by specifying the prune = False conversion option:

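from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline
from xgboost import XGBClassifier

pipeline = PMMLPipeline([
  ("xgb", XGBClassifier())
])
# X, y: your own training data
pipeline.fit(X, y)
# THIS - specify conversion options right after fitting the pipeline
pipeline.configure(prune = False)

sklearn2pmml(pipeline, "XGBoost.pmml")
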
xiaoqiao21 commented 3 years ago

Thanks a lot! It works now.