NVIDIA / spark-rapids-examples

A repo for all spark examples using Rapids Accelerator including ETL, ML/DL, etc.
Apache License 2.0
129 stars 51 forks source link

Stage failed errror when collecting results in Spark Deep Learning inference notebook in Windows OS #328

Open janaccnt opened 1 year ago

janaccnt commented 1 year ago

Discussed in https://github.com/NVIDIA/spark-rapids-examples/discussions/327

Originally posted by **janaccnt** November 4, 2023 I try several [Spark Deep Learning inference notebooks](https://github.com/NVIDIA/spark-rapids-examples/tree/branch-23.06/examples/ML%2BDL-Examples/Spark-DL/dl_inference) on **Windows**. I run Spark in standalone mode with 1 worker with 12 cores (both driver-memory and executor-memory are set to 8G). I always get the same error when applying the deep learning model to the test dataset. For example, in the [MNIST image classification notebook](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-23.06/examples/ML%2BDL-Examples/Spark-DL/dl_inference/tensorflow/image_classification.ipynb), when running this cell: ``` %%time # first pass caches model/fn preds = df.withColumn("preds", mnist(struct(df.columns))).collect() ``` I always get this error: > Py4JJavaError: An error occurred while calling o890.collectToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 9 in stage 6.0 failed 4 times, most recent failure: Lost task 9.3 in stage 6.0 (TID 69) (192.168.x.x executor 0): java.net.SocketException: Connection reset at java.net.SocketInputStream.read(SocketInputStream.java:210) at java.net.SocketInputStream.read(SocketInputStream.java:141) at java.io.BufferedInputStream.fill(BufferedInputStream.java:246) at java.io.BufferedInputStream.read(BufferedInputStream.java:265) at java.io.DataInputStream.readInt(DataInputStream.java:387) at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:105) at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525) at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460) at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364) at org.apache.spark.rdd.RDD.iterator(RDD.scala:328) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93) at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161) at org.apache.spark.scheduler.Task.run(Task.scala:141) at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620) at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64) at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) Driver stacktrace: at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2844) at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2780) at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2779) at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2779) at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1242) at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1242) at scala.Option.foreach(Option.scala:407) at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1242) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3048) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2982) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2971) at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49) at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:984) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2463) at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1046) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112) at org.apache.spark.rdd.RDD.withScope(RDD.scala:407) at org.apache.spark.rdd.RDD.collect(RDD.scala:1045) at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:448) at org.apache.spark.sql.Dataset.$anonfun$collectToPython$1(Dataset.scala:4160) at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4334) at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546) at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4332) at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125) at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201) at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4332) at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:4157) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374) at py4j.Gateway.invoke(Gateway.java:282) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182) at py4j.ClientServerConnection.run(ClientServerConnection.java:106) at java.lang.Thread.run(Thread.java:748) Caused by: java.net.SocketException: Connection reset at java.net.SocketInputStream.read(SocketInputStream.java:210) at java.net.SocketInputStream.read(SocketInputStream.java:141) at java.io.BufferedInputStream.fill(BufferedInputStream.java:246) at java.io.BufferedInputStream.read(BufferedInputStream.java:265) at java.io.DataInputStream.readInt(DataInputStream.java:387) at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:105) at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525) at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460) at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364) at org.apache.spark.rdd.RDD.iterator(RDD.scala:328) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93) at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161) at org.apache.spark.scheduler.Task.run(Task.scala:141) at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620) at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64) at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) ... 1 more I created and configured sparksession as below: ``` import findspark findspark.init() import pyspark from pyspark.sql import SparkSession from pyspark.conf import SparkConf spark = SparkSession \ .builder \ .appName("spark-dl-inference") \ .master("spark://192.168.x.x:7077") \ .config("spark.executor.memory", "8g") \ .config("spark.driver.memory", "8g") \ .config("spark.python.worker.reuse",True) \ .getOrCreate() ``` This is the web ui of the master: ![master](https://github.com/NVIDIA/spark-rapids-examples/assets/68731664/34fa01d1-029a-449d-941e-e707d9ecfd81) This is the web ui of the worker: ![worker](https://github.com/NVIDIA/spark-rapids-examples/assets/68731664/ab16ddd1-c5c6-4903-9b8a-e2f3f408e8aa) My window laptop has 24gb RAM. I've also reduced the model size (just 100k parameters) and the test dataset (only 70 images). I've tried increase the executor.memory to 12g but still got the same error. The code works perfectly in Google Colab (Spark run in Local mode with 1 core). I've try runing Spark in both Local and Standalone mode in Windows. I've also tried changing the number of core to 1 but still got the same error. So I guess the error might be related to the Windows OS which I'm using.
GaryShen2008 commented 3 months ago

Hi @leewyang, Not sure if we can close this issue.