microsoft / SynapseML

Simple and Distributed Machine Learning
http://aka.ms/spark
MIT License

ClassCastException with LgbmRanker train #823

Open arijeetm1 opened 4 years ago

arijeetm1 commented 4 years ago

Describe the bug
java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Double is thrown during training.

To Reproduce

from mmlspark.lightgbm import LightGBMRanker  # import path may differ across mmlspark versions

model = LightGBMRanker(
    boostingType='gbdt',
    objective='lambdarank',
    maxPosition=50,
    isProvideTrainingMetric=True,
    maxBin=255,
    evalAt=[10],
    numIterations=500,
    learningRate=0.3,
    numLeaves=127,
    earlyStoppingRound=20,
    # parallelism='serial',
    # num_threads=8,
    featureFraction=0.5,
    baggingFreq=1,
    baggingFraction=0.8,
    # min_data_in_leaf=20,
    minSumHessianInLeaf=0.001,
    # is_enable_sparse=True,
    # use_two_round_loading=True,
    # is_save_binary_file=False,
    groupCol='label',
    labelGain=[0.0, 1.0, 3.0, 7.0, 15.0],
    categoricalSlotIndexes=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
).fit(data)

data.schema:

    StructType(List(StructField(label,DoubleType,true),StructField(features,VectorUDT,true)))
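A quick way to compare the declared schema type against the runtime type of the boxed values (a hypothetical sanity check against the DataFrame above, not part of the original report):

    # Declared type vs. the actual Python type of the first label value
    print(data.schema["label"].dataType)          # DoubleType per the schema above
    print(type(data.select("label").first()[0]))  # should be <class 'float'>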

Expected behavior
Training completes without the exception.


Stacktrace (from the Spark history server logs):

java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Double
    at scala.runtime.BoxesRunTime.unboxToDouble(BoxesRunTime.java:114)
    at org.apache.spark.sql.Row$class.getDouble(Row.scala:248)
    at org.apache.spark.sql.catalyst.expressions.GenericRow.getDouble(rows.scala:166)
    at com.microsoft.ml.spark.lightgbm.TrainUtils$$anonfun$3.apply(TrainUtils.scala:29)
    at com.microsoft.ml.spark.lightgbm.TrainUtils$$anonfun$3.apply(TrainUtils.scala:29)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
    at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
    at com.microsoft.ml.spark.lightgbm.TrainUtils$.generateDataset(TrainUtils.scala:29)
    at com.microsoft.ml.spark.lightgbm.TrainUtils$.translate(TrainUtils.scala:233)
    at com.microsoft.ml.spark.lightgbm.TrainUtils$.trainLightGBM(TrainUtils.scala:385)
    at com.microsoft.ml.spark.lightgbm.LightGBMBase$$anonfun$6.apply(LightGBMBase.scala:145)
    at com.microsoft.ml.spark.lightgbm.LightGBMBase$$anonfun$6.apply(LightGBMBase.scala:145)
    at org.apache.spark.sql.execution.MapPartitionsExec$$anonfun$5.apply(objects.scala:188)
    at org.apache.spark.sql.execution.MapPartitionsExec$$anonfun$5.apply(objects.scala:185)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:858)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:858)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
    at org.apache.spark.sql.execution.SQLExecutionRDD.compute(SQLExecutionRDD.scala:55)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:123)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)


Additional context
We could rule out the code at https://github.com/Azure/mmlspark/blob/master/src/main/scala/com/microsoft/ml/spark/lightgbm/TrainUtils.scala#L98 (where getDouble is called), since the value there is of type vector.

Looking through the Spark SQL query plan suggests a possible explanation involving the labels: the label is cast to int during projection and then cast back to double during deserialization, which could be related to this issue.

[Screenshot: Spark SQL query plan showing the label cast, 2020-03-10]
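The cast in the plan can also be inspected without the Spark UI (standard PySpark, not specific to mmlspark; a suggested check, not from the original report):

    # Print the analyzed/optimized plans; look for cast(label as int) / cast(... as double)
    data.explain(extended=True)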
imatiach-msft commented 4 years ago

based on the line number in the stack trace:

    at com.microsoft.ml.spark.lightgbm.TrainUtils$$anonfun$3.apply(TrainUtils.scala:29)

it looks like this error is due to this line:

    val labels = rows.map(row => row.getDouble(schema.fieldIndex(columnParams.labelColumn)))

it looks like your label column doesn't have doubles. I think this should already be fixed in the latest master by the castColumns method in file LightGBMBase.scala.
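On versions without that fix, an explicit cast on the caller side should have the same effect (a minimal sketch; `ranker` is a hypothetical name standing in for the LightGBMRanker configured above):

    from pyspark.sql.functions import col

    # Make sure the label column reaches the trainer as a double before fitting
    data = data.withColumn("label", col("label").cast("double"))
    model = ranker.fit(data)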

imatiach-msft commented 4 years ago

interestingly, this would actually violate the schema in your dataset above:

 StructType(List(StructField(label,DoubleType,true),StructField(features,VectorUDT,true)))

so I'm not quite sure what is happening there

silver6wings commented 4 years ago

I came across the same exception when featureFraction was an Integer, though that should not be the root cause in this case. Maybe some float parameters don't allow an Integer to be passed in? A sketch of that failure mode follows.
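A minimal sketch of the failure mode just described (exact behavior depends on the mmlspark version and its pyspark param converters):

    # featureFraction is a double-valued param on the JVM side:
    ranker = LightGBMRanker(featureFraction=1)    # int -> java.lang.Integer -> ClassCastException
    ranker = LightGBMRanker(featureFraction=1.0)  # float -> java.lang.Double -> OK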

hebo-yang commented 3 years ago

java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Double

Py4JJavaError                             Traceback (most recent call last)
      5     numIterations=380,
      6     lambdaL2=100
----> 7 ).fit(train_data)

/databricks/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
    130                 return self.copy(params)._fit(dataset)
    131             else:
--> 132                 return self._fit(dataset)
    133         else:
    134             raise ValueError("Params must be either a param map or a list/tuple of param maps, "

/databricks/spark/python/pyspark/ml/wrapper.py in _fit(self, dataset)
    293
    294     def _fit(self, dataset):
--> 295         java_model = self._fit_java(dataset)
    296         model = self._create_model(java_model)
    297         return self._copyValues(model)

/databricks/spark/python/pyspark/ml/wrapper.py in _fit_java(self, dataset)
    289         :return: fitted Java model
    290         """
--> 291         self._transfer_params_to_java()
    292         return self._java_obj.fit(dataset._jdf)
    293

/databricks/spark/python/pyspark/ml/wrapper.py in _transfer_params_to_java(self)
    122         for param in self.params:
    123             if self.isSet(param):
--> 124                 pair = self._make_java_param_pair(param, self._paramMap[param])
    125                 self._java_obj.set(pair)
    126             if self.hasDefault(param):

/databricks/spark/python/pyspark/ml/wrapper.py in _make_java_param_pair(self, param, value)
    113         java_param = self._java_obj.getParam(param.name)
    114         java_value = _py2java(sc, value)
--> 115         return java_param.w(java_value)
    116
    117     def _transfer_params_to_java(self):

/databricks/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258
   1259         for temp_arg in temp_args:

/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/databricks/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o907.w.
: java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Double
    at scala.runtime.BoxesRunTime.unboxToDouble(BoxesRunTime.java:114)
    at org.apache.spark.ml.param.DoubleParam.w(params.scala:330)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
    at py4j.Gateway.invoke(Gateway.java:295)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:251)
    at java.lang.Thread.run(Thread.java:748)
imatiach-msft commented 3 years ago

@hebo-yang oh this is very interesting. Based on that stack trace, it looks like some param is being converted to an int instead of a double on the java side when transferred from python, but the scala code expects a double-valued param. I think this is something in the pyspark bindings then. I will try to run this with the params you specified on a different dataset to see if I can reproduce.

hebo-yang commented 3 years ago

@imatiach-msft Thanks! Were you able to repro this?

hebo-yang commented 3 years ago

Finally figured out my error...omg, the scala code expects parameters like alpha and lambdaL2 to be floats. Mine happened to be int values, hence the error.
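For reference, the offending call from the traceback above only needs the int literal turned into a float (a sketch based on that snippet):

    model = LightGBMRanker(
        numIterations=380,   # int-valued params are fine as ints
        lambdaL2=100.0,      # double-valued param: pass a float, not an int
    ).fit(train_data)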