You can customize any model prior to passing it to the model selector as follows:

// Copy over all the default model types and parameter grids provided,
// while setting the weight column for the logistic regression model only
val modelsAndParams = BinaryClassificationModelSelector.Defaults.modelsAndParams.map {
  case (lg: OpLogisticRegression, grid) => lg.setWeightCol("myWeightColumn") -> grid.build()
  case (m, grid) => m -> grid.build()
}

val modelSelector = BinaryClassificationModelSelector
  .withCrossValidation(modelsAndParameters = modelsAndParams)
  .setInput(label, features)
Alternatively, you can fully override the param grid if needed:

val lr = new OpLogisticRegression().setWeightCol("myWeightColumn")
val lrParams = new ParamGridBuilder()
  .addGrid(lr.elasticNetParam, Array(1.0))
  .addGrid(lr.maxIter, Array(10))
  .addGrid(lr.regParam, Array(1000000.0, 0.0))
  .build()

val modelSelector = BinaryClassificationModelSelector
  .withCrossValidation(modelsAndParameters = Seq(lr -> lrParams))
  .setInput(label, features)
Ok, that's clear. Now the question is: how do we make myWeightColumn survive transmogrification, sanity checking, and raw feature filtering, so that the LogReg model can actually find it in the input dataframe?
RawFeatureFilter has a protectedFeatures property that lets you prevent specific raw features from being removed, and SanityChecker won't remove features that are not part of the feature vector.
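For example, a minimal sketch of protecting a raw feature in the workflow, assuming classWeight is the raw weight feature and trainReader/scoringReader are your data readers:

// Sketch: keep the classWeight raw feature from being removed by RawFeatureFilter
// (classWeight, label, prediction, trainReader and scoringReader are assumed to be defined elsewhere)
val workflow = new OpWorkflow()
  .setResultFeatures(label, prediction)
  .withRawFeatureFilter(
    trainingReader = Some(trainReader),
    scoringReader = Some(scoringReader),
    protectedFeatures = Array(classWeight)
  )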
As far as I remember, we tried that yesterday but got:
Caused by: java.lang.IllegalArgumentException: Field "classWeight" does not exist.
Available fields: label, featureVector, key
at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:267)
at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:267)
at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
at scala.collection.AbstractMap.getOrElse(Map.scala:59)
at org.apache.spark.sql.types.StructType.apply(StructType.scala:266)
at org.apache.spark.ml.util.SchemaUtils$.checkNumericType(SchemaUtils.scala:71)
at org.apache.spark.ml.PredictorParams$class.validateAndTransformSchema(Predictor.scala:58)
at org.apache.spark.ml.classification.Classifier.org$apache$spark$ml$classification$ClassifierParams$$super$validateAndTransformSchema(Classifier.scala:58)
at org.apache.spark.ml.classification.ClassifierParams$class.validateAndTransformSchema(Classifier.scala:42)
at org.apache.spark.ml.classification.ProbabilisticClassifier.org$apache$spark$ml$classification$ProbabilisticClassifierParams$$super$validateAndTransformSchema(ProbabilisticClassifier.scala:53)
at org.apache.spark.ml.classification.ProbabilisticClassifierParams$class.validateAndTransformSchema(ProbabilisticClassifier.scala:37)
at org.apache.spark.ml.classification.LogisticRegression.org$apache$spark$ml$classification$LogisticRegressionParams$$super$validateAndTransformSchema(LogisticRegression.scala:278)
at org.apache.spark.ml.classification.LogisticRegressionParams$class.validateAndTransformSchema(LogisticRegression.scala:265)
at org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema(LogisticRegression.scala:278)
at org.apache.spark.ml.Predictor.transformSchema(Predictor.scala:144)
at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:100)
at com.salesforce.op.stages.sparkwrappers.specific.OpPredictorWrapper.fit(OpPredictorWrapper.scala:99)
at com.salesforce.op.stages.sparkwrappers.specific.OpPredictorWrapper.fit(OpPredictorWrapper.scala:67)
at org.apache.spark.ml.Estimator.fit(Estimator.scala:61)
at org.apache.spark.ml.Estimator$$anonfun$fit$1.apply(Estimator.scala:82)
at org.apache.spark.ml.Estimator$$anonfun$fit$1.apply(Estimator.scala:82)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
at org.apache.spark.ml.Estimator.fit(Estimator.scala:82)
at com.salesforce.op.stages.impl.tuning.OpValidator$$anonfun$9$$anonfun$apply$2.apply(OpValidator.scala:293)
at com.salesforce.op.stages.impl.tuning.OpValidator$$anonfun$9$$anonfun$apply$2.apply(OpValidator.scala:289)
at scala.concurrent.impl.Future$PromiseCompletingRunnable.liftedTree1$1(Future.scala:24)
at scala.concurrent.impl.Future$PromiseCompletingRunnable.run(Future.scala:24)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
It seems these are the only columns that get passed to Spark LogisticRegression with:
val prediction = BinaryClassificationModelSelector
  .withCrossValidation(
    splitter = Some(dataBalancer),
    validationMetric = LogLoss.binaryLogLoss,
    trainTestEvaluators = Seq(LogLoss.binaryLogLoss),
    modelTypesToUse = Seq(MTT.OpLogisticRegression),
    stratify = false,
    modelsAndParameters = Seq(lr -> lrParams.build()),
    seed = randomSeed
  )
  .setInput(label -> featureVector)
  .getOutput()
val model = new OpWorkflow()
  .setResultFeatures(label, prediction)
  .withRawFeatureFilter(
    trainingReader = Some(trainReader),
    scoringReader = Some(scoringReader),
    protectedFeatures = Array(classWeight)
  )
  .train()
Oh, I see. The Spark LogisticRegression stage actually expects the weightCol param to be a column of the input dataframe. Unfortunately, we currently only pass labelCol and featuresCol to the modeling stages. Sounds like a great feature to add to TransmogrifAI then ;)
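For reference, here is what plain Spark ML expects, as a minimal sketch assuming a DataFrame df that already contains label, features, and classWeight columns:

import org.apache.spark.ml.classification.LogisticRegression

// Plain Spark ML: weightCol must name an existing numeric column of the input DataFrame
// (the DataFrame df and its "classWeight" column are assumptions for illustration)
val sparkLr = new LogisticRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("classWeight")
val sparkModel = sparkLr.fit(df)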
Thanks @tovbinm. On the other hand, how do you deal with class imbalance in this case?
You can use DataBalancer to rebalance your dataset prior to modeling:
val dataBalancer = DataBalancer(sampleFraction = 0.2)
BinaryClassificationModelSelector.withCrossValidation(splitter = Some(dataBalancer))
Or DataCutter for multi-class problems:
val dataCutter = DataCutter(maxLabelCategories = 50, minLabelFraction = 0.1)
MultiClassificationModelSelector.withCrossValidation(splitter = Some(dataCutter))
Problem
We have an imbalanced dataset and want to use the LogisticRegression weightCol, but no luck so far. Is there any way to use this functionality?

Solution
Unknown

Alternatives
Using bare Spark LogisticRegression