jpmml / jpmml-sparkml

Java library and command-line application for converting Apache Spark ML pipelines to PMML
GNU Affero General Public License v3.0

Confusion of Softmax and Logit #71

Closed dlepzelter closed 5 years ago

dlepzelter commented 5 years ago

In trying to implement a Gradient-Boosted Tree model I trained in Spark via JPMML, I ended up finding that the probabilities it output were not the same as those in my model. On looking further into this, it seems to be because of confusion over the term "logit."

The ordinal logit function, as defined in http://dmg.org/pmml/v4-2-1/Regression.html, is the inverse of the softmax function. "Logit normalization" is designed to take a logit (unnormalized log probability) and turn it into a normalized probability, and this is done using the softmax function. Thus, the "logit normalization" in Spark should really be translated to the softmax function instead of the ordinal logit.
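To make the terminology concrete, here is a minimal sketch (not JPMML code) of softmax as the normalizer the comment describes: it maps a vector of unnormalized log-probabilities (logits) to probabilities that sum to one. The example logits are hypothetical.

```python
import numpy as np

def softmax(logits):
    """Map unnormalized log-probabilities (logits) to probabilities."""
    e = np.exp(logits - np.max(logits))  # shift by the max for numerical stability
    return e / e.sum()

# Hypothetical logits for a 3-class problem
logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)
# probs sums to 1; each entry equals exp(logit_i) / sum_j exp(logit_j)
```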

vruusmann commented 5 years ago

What's your Apache Spark version, what's the task (binary classification, multi-class classification?), and what's your GBTClassifier configuration?

JPMML-SparkML includes an integration test for the binary classification case, and in that case the probabilities are correct to within 1e-13 absolute and relative error: https://github.com/jpmml/jpmml-sparkml/blob/master/src/test/resources/csv/GBTAudit.csv

dlepzelter commented 5 years ago

PySpark 2.4.3, task is binary classification, GBTClassifier is mostly using defaults. I've included the relevant line here: gbt = GBTClassifier().setLabelCol("label").setFeaturesCol("features").setMaxIter(20).setMaxDepth(7).setMaxBins(35)

That said, I've noticed that despite the task being binary classification, it may essentially treat it as multi-class; it predicts two different probabilities instead of one, and the probabilities it gives me on transformation are those that would arise from using the multi-class version of softmax instead of the binary one.
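The discrepancy pattern described above can be sketched numerically (a hypothetical illustration, not the actual Spark or JPMML code paths): applying the multi-class softmax to a two-element raw-score vector [-m, m] yields sigmoid(2*m), which disagrees with the binary sigmoid(m) for any nonzero margin.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(v):
    e = np.exp(v - np.max(v))
    return e / e.sum()

m = 0.8  # hypothetical raw margin from a binary GBT
p_binary = sigmoid(m)                          # binary "logit normalization" of the margin
p_multi = softmax(np.array([-m, m]))[1]        # multi-class softmax over the two raw scores
# Algebraically softmax([-m, m])[1] == sigmoid(2*m), so the two disagree unless m == 0
```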

vruusmann commented 5 years ago

I'd like to see a reproducible example about mis-predicted probabilities. You can use my Audit.csv file as as a toy dataset (Adjusted ~ .): https://github.com/jpmml/jpmml-sparkml/blob/master/src/test/resources/csv/Audit.csv

The PMML representation (as implemented in the JPMML ecosystem) is about functional equivalence with the original ML framework. Sometimes it makes sense to replace softmax with logit when doing so yields a considerably more compact/readable model representation.
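The functional-equivalence point can be checked with a small sketch (hypothetical scores, not JPMML internals): for a binary problem with raw scores [0, s], the two-class softmax reduces exactly to the logistic (inverse-logit) function, so substituting one normalization for the other preserves the predicted probabilities.

```python
import math

def sigmoid(x):
    """Inverse-logit (logistic) function."""
    return 1.0 / (1.0 + math.exp(-x))

def softmax2(a, b):
    """Probability of the second class under a two-class softmax."""
    ea, eb = math.exp(a), math.exp(b)
    return eb / (ea + eb)

s = 1.3  # hypothetical decision-function score
# With scores [0, s], softmax and the logistic function agree exactly:
assert abs(softmax2(0.0, s) - sigmoid(s)) < 1e-12
```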

I can see that you're using the term "ordinal" in the original comment. PMML has an "ordinal" operational type, which is currently not implemented in JPMML-SparkML. Perhaps you've been talking about "probabilities of an ordinal target field" whereas I've been talking about "probabilities of a categorical target field" (in the context of a binary classification task)?

dlepzelter commented 5 years ago

I'll apply my script to that data when I can, and see what comes out. It may be a little while; I've got several balls in the air at the moment, as it were.

dlepzelter commented 5 years ago

Currently getting the following:

IllegalArgumentException: 'Field Employment has valid values [Private, Consultant, PSLocal, SelfEmp, PSState, PSFederal, Volunteer]'

I'm using pyspark2pmml at the moment (apologies for posting on the wrong board, though I know you maintain both).

from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, VectorIndexer
from pyspark.ml.classification import GBTClassifier

conf = SparkConf().setAppName('appName').setMaster('local')
sc = SparkContext(conf=conf)
sqc = SparkSession(sc)

df = sqc.read.option("header", "true").option("inferSchema", "true").csv("/Users/dlepzelt/Documents/testdata.csv")

boostedTrees = GBTClassifier().setLabelCol("Adjusted").setFeaturesCol("features").setMaxIter(20).setMaxDepth(7).setMaxBins(35)

# Index the string-valued columns
strcols = ["Employment", "Education", "Marital", "Occupation", "Gender"]
stind = [StringIndexer(inputCol=c, outputCol=c + "Indexed") for c in strcols]

# Assemble all features (numeric columns plus indexed string columns) into a single vector
collist = [c for c in df.columns if c not in strcols + ["Adjusted"]] + [c + "Indexed" for c in strcols]
vectorAssembler = VectorAssembler().setInputCols(collist).setOutputCol("rawFeatures")
vectorIndexer = VectorIndexer().setInputCol("rawFeatures").setOutputCol("features").setMaxCategories(35)

p = Pipeline().setStages(stind + [vectorAssembler, vectorIndexer, boostedTrees])
pipe = p.fit(df)
#pipe.save("~/Documents/GBTModelforreadmission")
pipe.transform(df).drop("rawFeatures").drop("features").write.parquet("~/Documents/GBTSample_gitdata.parquet")

from pyspark2pmml import PMMLBuilder
pmmlBuilder = PMMLBuilder(sc, df, pipe)
pmmlBuilder.buildFile("/Users/dlepzelt/Documents/gbt_gitdata.pmml")

The error comes after the pipeline is fit and the dataframe transformed; it's only in the PMMLBuilder step that it has issues.

dlepzelter commented 5 years ago

Any thoughts on that more recent error? Also, regarding "ordinal" vs. anything else: what I know is, there's a Spark version, a JPMML version, and a version translated into another language from the PMML, and the latter two don't end up agreeing with the Spark one but do end up agreeing with each other. My assumption, given that the other version seemed to be applying the "ordinal" definition from PMML, was that this was the issue; otherwise, something else is off.

vruusmann commented 5 years ago

IllegalArgumentException: 'Field Employment has valid values [Private, Consultant, PSLocal, SelfEmp, PSState, PSFederal, Volunteer]'

Duplicate of https://github.com/jpmml/jpmml-sparkml/issues/73

dlepzelter commented 5 years ago

For what it's worth, I've gone back and redesigned the pipeline to allow it to go through normally, and found that the problem is not, in fact, where I'd been assured it was. It seems to be elsewhere, possibly in versioning with a private library built on top of JPMML. My apologies; I'm closing this.