dvgodoy / handyspark

HandySpark - bringing pandas-like capabilities to Spark dataframes
MIT License

Confusion matrix not working on Databricks 7.4 ML (Spark 3.0.1) #25

Closed. DrYSG closed this issue 2 years ago

DrYSG commented 3 years ago

Using a Databricks 7.4 ML cluster with Spark 3.0.1.

(Also, I think you are using MulticlassMetrics to get the AUC value, but that value seems to come out high for an LR model. See: https://stackoverflow.com/questions/60772315/how-to-evaluate-a-classifier-with-apache-spark-2-4-5-and-pyspark-python )

I get this error when trying to do the confusion matrix:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<command-3318334168327659> in <module>
----> 1 bcm.print_confusion_matrix(.572006)

/databricks/python/lib/python3.7/site-packages/handyspark/extensions/evaluation.py in print_confusion_matrix(self, threshold)
    111     confusionMatrix: pd.DataFrame
    112     """
--> 113     cm = self.confusionMatrix(threshold).toArray()
    114     df = pd.concat([pd.DataFrame(cm)], keys=['Actual'], names=[])
    115     df.columns = pd.MultiIndex.from_product([['Predicted'], df.columns])

/databricks/python/lib/python3.7/site-packages/handyspark/extensions/evaluation.py in confusionMatrix(self, threshold)
     92     """
     93     scoreAndLabels = self.call2('scoreAndLabels').map(lambda t: (float(t[0] > threshold), t[1]))
---> 94     mcm = MulticlassMetrics(scoreAndLabels)
     95     return mcm.confusionMatrix()
     96 

/databricks/spark/python/pyspark/mllib/evaluation.py in __init__(self, predictionAndLabels)
    254         sc = predictionAndLabels.ctx
    255         sql_ctx = SQLContext.getOrCreate(sc)
--> 256         numCol = len(predictionAndLabels.first())
    257         schema = StructType([
    258             StructField("prediction", DoubleType(), nullable=False),

/databricks/spark/python/pyspark/rdd.py in first(self)
   1491         ValueError: RDD is empty
   1492         """
-> 1493         rs = self.take(1)
   1494         if rs:
   1495             return rs[0]

/databricks/spark/python/pyspark/rdd.py in take(self, num)
   1473 
   1474             p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
-> 1475             res = self.context.runJob(self, takeUpToNumLeft, p)
   1476 
   1477             items += res

/databricks/spark/python/pyspark/context.py in runJob(self, rdd, partitionFunc, partitions, allowLocal)
   1225             finally:
   1226                 os.remove(filename)
-> 1227         sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
   1228         return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer))
   1229 

/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1303         answer = self.gateway_client.send_command(command)
   1304         return_value = get_return_value(
-> 1305             answer, self.gateway_client, self.target_id, self.name)
   1306 
   1307         for temp_arg in temp_args:

/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
    125     def deco(*a, **kw):
    126         try:
--> 127             return f(*a, **kw)
    128         except py4j.protocol.Py4JJavaError as e:
    129             converted = convert_exception(e.java_exception)

/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.runJob.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 2131.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2131.0 (TID 9926, ip-10-172-254-217.us-west-2.compute.internal, executor driver): java.lang.ClassCastException: scala.Tuple3 cannot be cast to scala.Tuple2
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
    at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.next(SerDeUtil.scala:159)
    at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.next(SerDeUtil.scala:150)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)
    at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.foreach(SerDeUtil.scala:150)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:442)
    at org.apache.spark.api.python.PythonRunner$$anon$2.writeIteratorToStream(PythonRunner.scala:703)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:479)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2146)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:271)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2519)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2466)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2460)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2460)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1152)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1152)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1152)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2721)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2668)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2656)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:938)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2331)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2352)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2371)
    at org.apache.spark.api.python.PythonRDD$.collectPartitions(PythonRDD.scala:197)
    at org.apache.spark.api.python.PythonRDD$.runJob(PythonRDD.scala:217)
    at org.apache.spark.api.python.PythonRDD.runJob(PythonRDD.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
    at py4j.Gateway.invoke(Gateway.java:295)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:251)
    at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.ClassCastException: scala.Tuple3 cannot be cast to scala.Tuple2
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
    at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.next(SerDeUtil.scala:159)
    at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.next(SerDeUtil.scala:150)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)
    at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.foreach(SerDeUtil.scala:150)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:442)
    at org.apache.spark.api.python.PythonRunner$$anon$2.writeIteratorToStream(PythonRunner.scala:703)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:479)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2146)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:271)

I also get deprecation warnings when doing the ROC plots:

[image: deprecation warnings shown when producing the ROC plots]

dvgodoy commented 3 years ago

Hi,

Thanks for reporting this. Unfortunately, HandySpark does not support Spark 3.0. Spark 3.0 introduced a lot of changes, which is why you are seeing these errors. Some of HandySpark's functionality is covered by Databricks' own Koalas package.

Best, Daniel
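
For anyone who hits the same traceback: the `scala.Tuple3 cannot be cast to scala.Tuple2` message suggests that, under Spark 3.0, the score-and-label records carry a third field (possibly a weight), although nothing in this thread confirms that. The sketch below is plain Python, not HandySpark's code. It applies the same rule as the `confusionMatrix` line in the traceback (`score > threshold`) but reads only the first two fields of each record, so it accepts both 2-tuples and 3-tuples. The function name `confusion_matrix_at` and the sample rows are made up for illustration.

```python
from collections import Counter


def confusion_matrix_at(rows, threshold):
    """Return [[TN, FP], [FN, TP]] for binary labels 0.0 / 1.0.

    Each row is (score, label) or (score, label, weight). Only the first
    two fields are read, so the tuple length does not matter. Weights,
    if present, are ignored in this sketch.
    """
    counts = Counter()
    for row in rows:
        score, label = row[0], row[1]
        # Same thresholding rule as in the traceback: score > threshold.
        predicted = 1 if score > threshold else 0
        counts[(int(label), predicted)] += 1
    # Rows are actual classes, columns are predicted classes.
    return [
        [counts[(0, 0)], counts[(0, 1)]],
        [counts[(1, 0)], counts[(1, 1)]],
    ]


# Hypothetical (score, label, weight) records, collected to the driver.
rows = [
    (0.9, 1.0, 1.0),
    (0.8, 1.0, 1.0),
    (0.4, 1.0, 1.0),
    (0.7, 0.0, 1.0),
    (0.1, 0.0, 1.0),
]
cm = confusion_matrix_at(rows, 0.572006)
print(cm)
```

This only makes sense for data small enough to collect to the driver. On a real cluster the same thresholding and counting would have to be expressed with DataFrame operations instead.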