jpmml / jpmml-sparkml

Java library and command-line application for converting Apache Spark ML pipelines to PMML
GNU Affero General Public License v3.0

spark gbtmodel Segmentation, MiningField as feature? #47

Closed: haofengrushui204 closed this issue 5 years ago

haofengrushui204 commented 6 years ago

Hello, my PMML is as follows. I do not know why "label" has usageType="target" in the top-level MiningSchema, but is an active field in MiningModel/Segmentation/Segment[@id=1]. Why is that?


<DataDictionary>
        <DataField name="label" optype="categorical" dataType="double">
            <Value value="0.0"/>
            <Value value="1.0"/>
        </DataField>
        <DataField name="feature_442" optype="continuous" dataType="double"/>
        <DataField name="feature_443" optype="continuous" dataType="double"/>
        <DataField name="feature_481" optype="continuous" dataType="double"/>
        <DataField name="feature_894" optype="continuous" dataType="double"/>
        <DataField name="feature_1862" optype="continuous" dataType="double"/>
    </DataDictionary>
    <MiningModel functionName="classification">
        <MiningSchema>
            <MiningField name="label" usageType="target"/>
            <MiningField name="feature_442"/>
            <MiningField name="feature_443"/>
            <MiningField name="feature_481"/>
            <MiningField name="feature_894"/>
            <MiningField name="feature_1862"/>
        </MiningSchema>
        <Segmentation multipleModelMethod="modelChain">
            <Segment id="1">
                <True/>
                <MiningModel functionName="regression">
                    <MiningSchema>
                        <MiningField name="feature_442"/>
                        <MiningField name="feature_443"/>
                        <MiningField name="feature_481"/>
                        <MiningField name="feature_894"/>
                        <MiningField name="feature_1862"/>
                        <MiningField name="label"/>
                    </MiningSchema>
                    <Output>
                        <OutputField name="gbtValue" optype="continuous" dataType="double" feature="predictedValue" isFinalResult="false"/>
                        <OutputField name="binarizedGbtValue" optype="continuous" dataType="double" feature="transformedValue" isFinalResult="false">
                            <Apply function="if">
                                <Apply function="greaterThan">
                                    <FieldRef field="gbtValue"/>
                                    <Constant dataType="double">0</Constant>
                                </Apply>
                                <Constant dataType="double">-1</Constant>
                                <Constant dataType="double">1</Constant>
                            </Apply>
                        </OutputField>
                    </Output>
                    <Segmentation multipleModelMethod="sum">
                        <Segment id="1">
                            <True/>
                            <TreeModel functionName="regression" splitCharacteristic="binarySplit">
                                <MiningSchema>
                                    <MiningField name="label"/>
                                </MiningSchema>
                                <Node score="-0.08980349484734046">
                                    <True/>
                                    <Node score="-1">
                                        <SimplePredicate field="label" operator="lessOrEqual" value="0"/>
                                    </Node>
                                    <Node score="1">
                                        <SimplePredicate field="label" operator="greaterThan" value="0"/>
                                    </Node>
                                </Node>
                            </TreeModel>
                        </Segment>
                        <Segment id="2">
                            <True/>
                            <TreeModel functionName="regression" splitCharacteristic="binarySplit">
                                <MiningSchema>
                                    <MiningField name="feature_442"/>
                                    <MiningField name="feature_443"/>
                                    <MiningField name="feature_481"/>
                                    <MiningField name="feature_894"/>
                                    <MiningField name="feature_1862"/>
                                    <MiningField name="label"/>
                                </MiningSchema>
                                <Targets>
                                    <Target rescaleFactor="0.1"/>
                                </Targets>
                                <Node score="-0.04281935597440249">
                                    <True/>
                                    <Node score="-0.47681168808845653">
                                        <SimplePredicate field="label" operator="lessOrEqual" value="0"/>
                                        <Node score="-0.47681168808847174">
                                            <SimplePredicate field="feature_442" operator="lessOrEqual" value="-0.5888127277121523"/>
                                            <Node score="-0.4768116880884725">
                                                <SimplePredicate field="feature_894" operator="lessOrEqual" value="-0.6830283900955506"/>
                                            </Node>
                                            <Node score="-0.47681168808847285">
                                                <SimplePredicate field="feature_894" operator="greaterThan" value="-0.6830283900955506"/>
                                            </Node>
                                        </Node>
                                        <Node score="-0.47681168808847096">
                                            <SimplePredicate field="feature_442" operator="greaterThan" value="-0.5888127277121523"/>
                                            <Node score="-0.47681168808847013">
                                                <SimplePredicate field="feature_443" operator="lessOrEqual" value="-1.2352702594745397"/>
                                            </Node>
                                            <Node score="-0.4768116880884723">
                                                <SimplePredicate field="feature_443" operator="greaterThan" value="-1.2352702594745397"/>
                                            </Node>
                                        </Node>
                                    </Node>
                                    <Node score="0.47681168808845853">
                                        <SimplePredicate field="label" operator="greaterThan" value="0"/>
                                        <Node score="0.47681168808846963">
                                            <SimplePredicate field="feature_1862" operator="lessOrEqual" value="-1.38258310890975"/>
                                            <Node score="0.4768116880884702">
                                                <SimplePredicate field="feature_481" operator="lessOrEqual" value="-1.128558484240802"/>
                                            </Node>
                                            <Node score="0.4768116880884703">
                                                <SimplePredicate field="feature_481" operator="greaterThan" value="-1.128558484240802"/>
                                            </Node>
                                        </Node>
                                        <Node score="0.47681168808847163">
                                            <SimplePredicate field="feature_1862" operator="greaterThan" value="-1.38258310890975"/>
                                        </Node>
                                    </Node>
                                </Node>
                            </TreeModel>
                        </Segment>
                    </Segmentation>
                </MiningModel>
            </Segment>
...
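A quick way to spot this kind of problem in a larger PMML document is to look for fields that are declared as a target in one MiningSchema but appear as active fields elsewhere. The following is a minimal heuristic sketch using only the JDK DOM parser. The object and method names are hypothetical, and it assumes that a missing usageType means "active" (the PMML default) and that "predicted" is the older spelling of "target".

```scala
import java.io.ByteArrayInputStream
import java.nio.charset.StandardCharsets
import javax.xml.parsers.DocumentBuilderFactory
import org.w3c.dom.{Element, NodeList}

// Hypothetical helper, not part of JPMML: heuristic target-leak detector.
object TargetLeakCheck {

  // Collect the Element nodes of a DOM NodeList.
  private def elements(nodes: NodeList): Seq[Element] =
    (0 until nodes.getLength).map(i => nodes.item(i)).collect { case e: Element => e }

  // usageType defaults to "active" when the attribute is absent.
  private def usage(e: Element): String =
    Option(e.getAttribute("usageType")).filter(_.nonEmpty).getOrElse("active")

  // Names of fields that are a target in some MiningSchema and an active
  // field in some (possibly nested) MiningSchema of the same document.
  def leakedTargets(pmmlXml: String): Set[String] = {
    val doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
      .parse(new ByteArrayInputStream(pmmlXml.getBytes(StandardCharsets.UTF_8)))
    val fields = elements(doc.getElementsByTagName("MiningField"))
    val targets = fields.filter(f => Set("target", "predicted")(usage(f))).map(_.getAttribute("name")).toSet
    val actives = fields.filter(usage(_) == "active").map(_.getAttribute("name")).toSet
    targets intersect actives
  }
}
```

Applied to the PMML above, this check would be expected to report "label".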
vruusmann commented 6 years ago

Can you show me your Apache Spark pipeline definition?

You appear to be using a model chain. The first model is GBT (Segment@id=1), which is then followed by some other model (deleted from the above PMML snippet).

The most likely explanation is that you're using the label column as an input column to GBT.
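To see why this matters, here is a hand-translated Scala sketch of the gbtValue computation for the two trees in the PMML above: the segment scores are summed (multipleModelMethod="sum"), and the second tree's score is multiplied by its Target rescaleFactor of 0.1. It assumes all inputs are present (PMML missing-value handling is ignored), and the object and method names are hypothetical. Because the first tree splits on "label" only, the sign of gbtValue is decided by the label alone, whatever the feature values are.

```scala
// Hypothetical hand translation of the two TreeModel segments above.
object GbtSketch {

  // Input values keyed by field name, e.g. "label" -> 1.0.
  type Row = Map[String, Double]

  // Segment id=1: splits on "label" only.
  def tree1(r: Row): Double =
    if (r("label") <= 0) -1.0 else 1.0

  // Segment id=2: first split on "label", then on regular features.
  def tree2(r: Row): Double =
    if (r("label") <= 0) {
      if (r("feature_442") <= -0.5888127277121523) {
        if (r("feature_894") <= -0.6830283900955506) -0.4768116880884725 else -0.47681168808847285
      } else {
        if (r("feature_443") <= -1.2352702594745397) -0.47681168808847013 else -0.4768116880884723
      }
    } else {
      if (r("feature_1862") <= -1.38258310890975) {
        if (r("feature_481") <= -1.128558484240802) 0.4768116880884702 else 0.4768116880884703
      } else 0.47681168808847163
    }

  // gbtValue: sum of segments, tree 2 rescaled by Target@rescaleFactor="0.1".
  def gbtValue(r: Row): Double = tree1(r) + 0.1 * tree2(r)
}
```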

haofengrushui204 commented 6 years ago

OK, thanks for your reply. This is my code:

  import org.apache.hadoop.conf.Configuration
  import org.apache.hadoop.fs.{FileSystem, Path}
  import org.apache.spark.ml.Pipeline
  import org.apache.spark.ml.classification.GBTClassifier
  import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
  import org.apache.spark.mllib.regression.LabeledPoint // assumed: the RDD-based MLlib LabeledPoint
  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.{Row, SparkSession}
  import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
  import org.jpmml.model.MetroJAXBUtil
  import org.jpmml.sparkml.ConverterUtil

  // modelPmmlPath and HDFSUtils are defined elsewhere in the project.
  def train(trainRDD: RDD[LabeledPoint], iterNum: Int, spark: SparkSession): Unit = {
    import spark.implicits._

    // Column names: feature_0 .. feature_{n-1}, followed by "label"
    val features_size = trainRDD.take(1)(0).features.size
    val featureNames = (0 until features_size).map(x => s"feature_$x") :+ "label"
    val schemaAnalysis = featureNames.map(featureName => StructField(featureName, DoubleType)).toArray

    // One Row per LabeledPoint: feature values first, label last
    val trainDFTmp = trainRDD.map { case LabeledPoint(label, features) =>
      val seq: Seq[Double] = features.toArray.toSeq
      Row.fromSeq(seq :+ label)
    }
    val trainDF = spark.createDataFrame(trainDFTmp, new StructType(schemaAnalysis))

    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")

    val vectorAssembler = new VectorAssembler()
      .setInputCols(featureNames.toArray)
      .setOutputCol("features")

//    val featuresIndexer = new VectorIndexer()
//      .setInputCol("features")
//      .setOutputCol("indexedFeatures")
//      .setMaxCategories(10)

    val gbt = new GBTClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("features")
      .setMaxDepth(3)
      .setMaxIter(iterNum)

    // Convert indexed labels back to original labels.
    //    val labelConverter = new IndexToString()
    //      .setInputCol("prediction")
    //      .setOutputCol("predictedLabel")
    //      .setLabels(labelIndexer.labels)

    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, vectorAssembler, gbt))
    val pipelineModel = pipeline.fit(trainDF)
//    HDFSUtils.deleteDir("hdfs://bitautodmp/data/datamining/ctr/model_gbdt")
//    pipelineModel.save("hdfs://bitautodmp/data/datamining/ctr/model_gbdt")
//    val pipelineModel = PipelineModel.load("hdfs://bitautodmp/data/datamining/ctr/model_gbdt")

    /**
      * Write the model to HDFS in PMML format
      */
    val pmml = ConverterUtil.toPMML(trainDF.schema, pipelineModel)
    HDFSUtils.deleteDir(modelPmmlPath)
    val fs: FileSystem = FileSystem.get(new Configuration())
    val path = new Path(modelPmmlPath)
    val out = fs.create(path)
    try {
      MetroJAXBUtil.marshalPMML(pmml, out)
    } finally {
      out.close() // flush and release the HDFS output stream
    }
  }
vruusmann commented 6 years ago

val featureNames = (0 until features_size).map(x => s"feature_$x") :+ "label"

I'm not very familiar with the Scala language, but is it possible that the above code line is building a collection of strings, where the last element is "label"?

This collection is then used to define the VectorAssembler input columns, which explains how the GBT model ends up including the "label" column as a regular input column.
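A small pure-Scala sketch confirms the suspicion: the expression from the posted code yields a collection whose last element is "label". One possible fix, shown here under hypothetical method names, is to keep one collection for the VectorAssembler inputs and append "label" only when building the DataFrame schema.

```scala
// Hypothetical helpers illustrating the two ways of building the column list.
object FeatureNamesSketch {

  // As in the posted code: feature columns followed by "label".
  // Suitable for the DataFrame schema, but not for VectorAssembler inputs.
  def withLabel(size: Int): IndexedSeq[String] =
    (0 until size).map(x => s"feature_$x") :+ "label"

  // Feature columns only: what VectorAssembler.setInputCols should receive.
  def featuresOnly(size: Int): IndexedSeq[String] =
    (0 until size).map(x => s"feature_$x")
}
```

In the pipeline above, that would mean passing `FeatureNamesSketch.featuresOnly(features_size).toArray` to `setInputCols`, while the StructType keeps using the list that ends with "label".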

vruusmann commented 6 years ago

It would be nice if the JPMML-SparkML library performed some additional sanity checks on the model schema definition: it should throw an exception if the same column is used both as a label and as a feature.
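Until such a check exists in the library, a similar guard can be run in user code before fitting the pipeline. The sketch below is a hypothetical helper, not part of the JPMML-SparkML API; it simply fails fast when the label column name is among the feature column names.

```scala
// Hypothetical user-side guard, not a JPMML-SparkML API.
object SchemaSanityCheck {

  // Throws if the label column is also listed as a feature column.
  def requireDisjoint(labelCol: String, featureCols: Seq[String]): Unit = {
    if (featureCols.contains(labelCol))
      throw new IllegalArgumentException(
        s"Column '$labelCol' is used both as the label and as a feature")
  }
}
```

For the pipeline in this thread, calling `SchemaSanityCheck.requireDisjoint("label", vectorAssembler.getInputCols.toSeq)` before `pipeline.fit` would have thrown, since the raw "label" column (from which "indexedLabel" is derived) is among the assembler inputs.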

haofengrushui204 commented 6 years ago

OK, thanks very much. As you said, I made a mistake.