jpmml / jpmml-sparkml

Java library and command-line application for converting Apache Spark ML pipelines to PMML
GNU Affero General Public License v3.0

MultilayerPerceptronClassificationModel IllegalArgumentException("Expected 3 target categories, got 2 target categories"); #115

laotang123 closed this issue 1 year ago

laotang123 commented 3 years ago

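Fixed in the org.jpmml.sparkml.ModelConverter class, method encodeSchema:

```java
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;

// author: liujunfeng
// Fix: the label categories and their count were detected incorrectly for the multilayer perceptron
if (model instanceof MultilayerPerceptronClassificationModel) {
	MultilayerPerceptronClassificationModel classificationModel = (MultilayerPerceptronClassificationModel) model;

	// Take the size of the last (output) layer as the number of target categories
	int[] layers = classificationModel.layers();
	numClasses = layers[layers.length - 1];
}
```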

vruusmann commented 1 year ago

The number of target categories is determined by the associated StringIndexerModel object, and is double-checked against the ClassificationModel#numClasses property.
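A minimal, self-contained sketch of that double-check follows. It assumes the category list comes from the label StringIndexerModel and that an MLP classifier's class count equals the size of its output layer. The names TargetCategoryCheck, numClasses and check are hypothetical and are not JPMML-SparkML API. The message format is copied from this issue's title, and the order of the two counts in the real message is an assumption.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of a target-category consistency check.
// Not the JPMML-SparkML implementation; names are illustrative only.
final class TargetCategoryCheck {

    private TargetCategoryCheck() {
    }

    // Assumption: for an MLP classifier, the class count is the output layer size.
    static int numClasses(int[] layers) {
        return layers[layers.length - 1];
    }

    // labelCategories stands in for the labels of the label StringIndexerModel.
    static void check(List<String> labelCategories, int numClasses) {
        if (labelCategories.size() != numClasses) {
            throw new IllegalArgumentException("Expected " + numClasses + " target categories, got "
                    + labelCategories.size() + " target categories");
        }
    }

    public static void main(String[] args) {
        // Made-up example: the indexer saw 2 labels, the output layer has 3 units.
        List<String> categories = Arrays.asList("no", "yes");
        int numClasses = numClasses(new int[]{4, 5, 3});
        try {
            check(categories, numClasses);
        } catch (IllegalArgumentException iae) {
            System.out.println(iae.getMessage());
        }
    }
}
```

With an output layer of 3 units and a label indexer that saw only 2 distinct values, this sketch prints the same message as the one reported in this issue.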

Starting from JPMML-SparkML version 2.X, the latter can be "customized" by overriding the ClassificationModelConverter#getNumClasses() method: https://github.com/jpmml/jpmml-sparkml/blob/2.0.2/pmml-sparkml/src/main/java/org/jpmml/sparkml/ClassificationModelConverter.java#L52-L56
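The override hook can be illustrated with plain Java inheritance. This is only a sketch. SketchClassificationConverter and SketchMlpConverter are hypothetical stand-ins and not the real ClassificationModelConverter, whose constructor, generics and exact getNumClasses() signature are not reproduced here. The layer sizes in the demo are made up.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for a classification converter base class.
// It only demonstrates an overridable getNumClasses() hook.
class SketchClassificationConverter {

    private final int modelNumClasses;

    SketchClassificationConverter(int modelNumClasses) {
        this.modelNumClasses = modelNumClasses;
    }

    // Default behaviour: trust the count reported by the model.
    int getNumClasses() {
        return modelNumClasses;
    }

    // The label check always goes through getNumClasses(), so subclasses can change it.
    void checkLabel(List<String> labelCategories) {
        int numClasses = getNumClasses();
        if (labelCategories.size() != numClasses) {
            throw new IllegalArgumentException("Expected " + numClasses + " target categories, got "
                    + labelCategories.size() + " target categories");
        }
    }
}

// Hypothetical subclass that "customizes" the count, mirroring the patch above.
class SketchMlpConverter extends SketchClassificationConverter {

    private final int[] layers;

    SketchMlpConverter(int modelNumClasses, int[] layers) {
        super(modelNumClasses);
        this.layers = layers.clone();
    }

    @Override
    int getNumClasses() {
        // Use the output layer size instead of the model-reported count.
        return layers[layers.length - 1];
    }
}

class OverrideDemo {

    public static void main(String[] args) {
        SketchClassificationConverter base = new SketchClassificationConverter(3);
        SketchClassificationConverter mlp = new SketchMlpConverter(3, new int[]{4, 5, 2});
        System.out.println("base: " + base.getNumClasses() + ", mlp: " + mlp.getNumClasses());
        // Passes, because the overridden count (2) matches the two label categories.
        mlp.checkLabel(Arrays.asList("a", "b"));
    }
}
```

The design point is that the check reads the count through a method rather than a field, so a subclass can replace it without touching the check itself.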

I personally haven't noticed any irregularities in this area when working with MLP classifiers.