jpmml / jpmml-sparkml

Java library and command-line application for converting Apache Spark ML pipelines to PMML
GNU Affero General Public License v3.0

SQLTransformer cannot find fields #109

Closed. zwag20 closed this issue 3 years ago.

zwag20 commented 3 years ago

I am having issues with PMMLBuilder and SQLTransformer. I am trying to replace nulls with a 0 in a pipeline, and thought this might be the easiest way to accomplish it. I can export the PMML fine when I don't include the SQLTransformer, so I assume I have a mistake somewhere. Here is some sample code I put together. I am running this on Databricks, if that matters.

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import StringIndexer, VectorAssembler, SQLTransformer
from pyspark2pmml import PMMLBuilder

# Prepare a training DataFrame from a list of (id, category, numcol, label) tuples.
training = spark.createDataFrame([
    (0, "abc", 3, 1.0),
    (1, "b", None, 0.0),
    (2, "spark", 8, 1.0),
    (3, "hadoop", 4, 0.0)
], ["id", "category", "numcol", "label"])

# Configure an ML pipeline consisting of four stages: fillna, indexer, assembler and rf.
fillna = SQLTransformer(statement="""SELECT
  category,
  CASE WHEN numcol IS NULL THEN 0 ELSE numcol END AS numcol,
  label
FROM __THIS__
""")
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
assembler = VectorAssembler(inputCols=["categoryIndex", "numcol"], outputCol="features")
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=5, maxDepth=3)
pipeline = Pipeline(stages=[fillna, indexer, assembler, rf])

model = pipeline.fit(training)

pmmlBuilder = PMMLBuilder(sc, training, model)
pmmlBuilder.buildFile("/dbfs/tmp/test.pmml")
vruusmann commented 3 years ago

I'm running this code example using JPMML-SparkML 1.6.3 on Apache Spark 3.0:

$ $SPARK_HOME/bin/pyspark --packages org.jpmml:jpmml-sparkml:1.6.3

This is the error that I am getting:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python2.7/site-packages/pyspark2pmml/__init__.py", line 21, in build
    return self.javaPmmlBuilder.build()
  File "/opt/spark-3.0.0/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py", line 1305, in __call__
  File "/opt/spark-3.0.0/python/pyspark/sql/utils.py", line 137, in deco
    raise_from(converted)
  File "/opt/spark-3.0.0/python/pyspark/sql/utils.py", line 33, in raise_from
    raise e
pyspark.sql.utils.IllegalArgumentException: Name(s) [numcol] do not match any fields

I can't find a substring "do not match" anywhere in the JPMML-SparkML library code. Perhaps it's a PySpark-level exception, not a JPMML-SparkML-level exception after all?

I should translate this example code to Scala and run it without the PySpark intermediate layer; that should reveal a more meaningful exception.

vruusmann commented 3 years ago

Translated this code example to Scala:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature._
import org.apache.spark.sql._
import org.apache.spark.sql.types._

val data = List(
    Row(0, "abc", 3, 1.0),
    Row(1, "b", null, 0.0),
    Row(2, "spark", 8, 1.0),
    Row(3, "hadoop", 4, 0.0)
)

val rdd = spark.sparkContext.parallelize(data)

val schema = StructType(Array(
    StructField("id", IntegerType, true),
    StructField("category", StringType, true),
    StructField("numcol", IntegerType, true),
    StructField("label", DoubleType, true)
))

val training = spark.createDataFrame(rdd, schema)

val fillna = new SQLTransformer().setStatement("select category, case when numcol is null then 0 else numcol end as imputed_numcol, label FROM __THIS__")
val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex")
val assembler = new VectorAssembler().setInputCols(Array("categoryIndex", "imputed_numcol")).setOutputCol("features")
val rf = new RandomForestClassifier().setNumTrees(5).setMaxDepth(3).setLabelCol("label").setFeaturesCol("features")

val pipeline = new Pipeline().setStages(Array(fillna, indexer, assembler, rf))
val pipelineModel = pipeline.fit(training)

import org.jpmml.sparkml.PMMLBuilder

val pmmlBuilder = new PMMLBuilder(training.schema, pipelineModel)
pmmlBuilder.build()

Running with the spark-shell command-line application:

$ $SPARK_HOME/bin/spark-shell --packages org.jpmml:jpmml-sparkml:1.6.3

This is the underlying exception:

java.lang.IllegalArgumentException: Name(s) [numcol] do not match any fields
  at org.jpmml.converter.visitors.FieldUtil.selectAll(FieldUtil.java:84)
  at org.jpmml.converter.visitors.FieldUtil.selectAll(FieldUtil.java:60)
  at org.jpmml.converter.visitors.FieldDependencyResolver.process(FieldDependencyResolver.java:226)
  at org.jpmml.converter.visitors.FieldDependencyResolver.visit(FieldDependencyResolver.java:97)
  at org.dmg.pmml.DerivedField.accept(DerivedField.java:240)
  at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:108)
  at org.dmg.pmml.TransformationDictionary.accept(TransformationDictionary.java:108)
  at org.dmg.pmml.PMMLObject.traverse(PMMLObject.java:90)
  at org.dmg.pmml.PMML.accept(PMML.java:237)
  at org.jpmml.model.visitors.AbstractVisitor.applyTo(AbstractVisitor.java:320)
  at org.jpmml.converter.visitors.DeepFieldResolver.applyTo(DeepFieldResolver.java:40)
  at org.jpmml.model.visitors.VisitorBattery.applyTo(VisitorBattery.java:26)
  at org.jpmml.converter.ModelEncoder.encodePMML(ModelEncoder.java:77)
  at org.jpmml.sparkml.PMMLBuilder.build(PMMLBuilder.java:212)
  ... 53 elided
vruusmann commented 3 years ago

> I can't find a substring "do not match" anywhere in the JPMML-SparkML library code. Perhaps it's a PySpark-level exception, not a JPMML-SparkML-level exception after all?

This exception is raised by the underlying JPMML-Converter library.

However, the bug probably resides somewhere at the JPMML-SparkML library level.

vruusmann commented 3 years ago

It should be noted that in my Scala example I renamed the transformed column from numcol to imputed_numcol. I was thinking that perhaps SQLTransformer has a problem with renaming columns in place, but that part seems to be okay.

zwag20 commented 3 years ago

Is this a bug that is fixable, or should I pursue another solution?

vruusmann commented 3 years ago

@zwag20 The bug was fixed one minute before you asked about its fixability. However, the fix is currently located in the 1.4.X branch (Apache Spark 2.3), and there's quite a lot of merging work to do before it reaches the master branch.

In brief, this issue was caused by the fact that the feature numcol was referenced at the "inner level" (here, inside the CaseWhen expression), but it was never referenced at the "outer level" (here, as a plain SELECT numcol, ... item).
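
To make the two levels concrete, here is the original statement again, annotated with SQL comments (the annotations are mine):

fillna = SQLTransformer(statement="""SELECT
  category,
  -- "inner level": numcol is referenced only inside the CASE WHEN expression
  CASE WHEN numcol IS NULL THEN 0 ELSE numcol END AS numcol,
  label
  -- "outer level": there is no plain top-level numcol item in the SELECT list
FROM __THIS__
""")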

zwag20 commented 3 years ago

Thanks for the update @vruusmann

vruusmann commented 3 years ago

@zwag20 Here's a workaround that works with current JPMML-SparkML library versions:

  1. In the SQL statement, make a top-level reference to the numcol field, and give the transformed field a different name, such as imputed_numcol.
  2. In the vector assembler, reference only the transformed field imputed_numcol.

Code example:

fillna = SQLTransformer(statement="""SELECT
  category,
  numcol,
  CASE WHEN numcol IS NULL THEN 0 ELSE numcol END AS imputed_numcol,
  label
FROM __THIS__
""")
assembler = VectorAssembler(inputCols=["categoryIndex", "imputed_numcol"], outputCol="features")
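
For completeness, a sketch of the full patched pipeline, reusing the training DataFrame, the remaining stages and the file path from the original example:

# Same stages as before; only fillna and assembler changed
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=5, maxDepth=3)

pipeline = Pipeline(stages=[fillna, indexer, assembler, rf])
model = pipeline.fit(training)

# numcol is now visible as a top-level field, so field resolution succeeds;
# the model itself consumes only the transformed imputed_numcol column
pmmlBuilder = PMMLBuilder(sc, training, model)
pmmlBuilder.buildFile("/dbfs/tmp/test.pmml")
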
zwag20 commented 3 years ago

Thanks for the workaround.

I have done some testing where I exported the PMML with your workaround, and another test where I manually changed the None to 0 and removed the fillna SQLTransformer stage. They both ran successfully, but the PMML files are identical. Is this supposed to be the case? I don't know much about PMML, but I would expect there to be code handling the nulls within the PMML file. Maybe I am mistaken.

vruusmann commented 3 years ago

The CaseWhen expression causes the following PMML snippet to be generated:

<DerivedField name="imputed_numcol" optype="continuous" dataType="integer">
    <Apply function="if">
        <Apply function="isMissing">
            <FieldRef field="numcol"/>
        </Apply>
        <Constant dataType="integer">0</Constant>
        <FieldRef field="numcol"/>
    </Apply>
</DerivedField>

Please note that JPMML-SparkML performs "model cleansing" automatically, where unused field definitions are removed from the PMML document.

So, if the imputed_numcol feature is not significant, there won't be a corresponding DerivedField element.
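
One way to check which case applies is to look for the DerivedField element in the exported document; a minimal sketch using the Python standard library (the file path is from the example above, and the PMML 4.4 namespace is an assumption that may need adjusting):

import xml.etree.ElementTree as ET

tree = ET.parse("/dbfs/tmp/test.pmml")
# JPMML-SparkML 1.6.X targets PMML 4.4; adjust the namespace for other versions
ns = {"pmml": "http://www.dmg.org/PMML-4_4"}

# Search the whole document for the imputed_numcol derived field
derived = tree.getroot().findall(".//pmml:DerivedField[@name='imputed_numcol']", ns)
print("DerivedField present:", bool(derived))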

zwag20 commented 3 years ago

Excellent!

That was my issue.

Thanks for all of your help!