intel-analytics / ipex-llm

Accelerate local LLM inference and finetuning (LLaMA, Mistral, ChatGLM, Qwen, Mixtral, Gemma, Phi, MiniCPM, Qwen-VL, MiniCPM-V, etc.) on Intel XPU (e.g., local PC with iGPU and NPU, discrete GPU such as Arc, Flex and Max); seamlessly integrate with llama.cpp, Ollama, HuggingFace, LangChain, LlamaIndex, vLLM, GraphRAG, DeepSpeed, Axolotl, etc.
Apache License 2.0

RuntimeError: Socket Timeout when running pyspark pytorch Orca Estimator training on multiple nodes with big models on K8S #8151

Open jenniew opened 1 year ago

jenniew commented 1 year ago

Running PyTorch PySpark Estimator training on multiple nodes on Kubernetes with big models sometimes fails with RuntimeError: Socket Timeout when the workers call init_process_group. The traceback is as below:

    Exception: Traceback (most recent call last):
      File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 604, in main
        process()
      File "/opt/spark/python/lib/pyspark.zip/pyspark/worker.py", line 594, in process
        out_iter = func(split_index, iterator)
      File "/opt/spark/python/lib/pyspark.zip/pyspark/rdd.py", line 2863, in func
      File "/opt/bigdl-2.3.0-SNAPSHOT/python/bigdl-orca-spark_3.1.3-2.3.0-SNAPSHOT-python-api.zip/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py", line 370, in
      File "/opt/bigdl-2.3.0-SNAPSHOT/python/bigdl-orca-spark_3.1.3-2.3.0-SNAPSHOT-python-api.zip/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py", line 367, in transform_func
      File "./bigdl-orca-spark_3.1.3-2.3.0-SNAPSHOT-python-api.zip/bigdl/orca/learn/pytorch/pytorch_pyspark_worker.py", line 93, in init
        self.setup_distributed(self.mode, cluster_info, driver_ip, driver_tcp_store_port)
      File "./bigdl-orca-spark_3.1.3-2.3.0-SNAPSHOT-python-api.zip/bigdl/orca/learn/pytorch/pytorch_pyspark_worker.py", line 116, in setup_distributed
        self.setup_torch_distribute(tcp_store_host=driver_ip,
      File "./bigdl-orca-spark_3.1.3-2.3.0-SNAPSHOT-python-api.zip/bigdl/orca/learn/pytorch/core/lifecycle.py", line 29, in setup_torch_distribute
        self._init_torch_ddp(tcp_store_host, tcp_store_port, world_rank,
      File "./bigdl-orca-spark_3.1.3-2.3.0-SNAPSHOT-python-api.zip/bigdl/orca/learn/pytorch/core/lifecycle.py", line 72, in _init_torch_ddp
        dist.init_process_group(
      File "/opt/spark/work-dir/lora2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 761, in init_process_group
        default_pg = _new_process_group_helper(
      File "/opt/spark/work-dir/lora2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 862, in _new_process_group_helper
        pg = ProcessGroupGloo(prefix_store, group_rank, group_size, timeout=timeout)
    RuntimeError: Socket Timeout

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:517)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:652)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:635)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:470)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)
    at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
    at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
jenniew commented 1 year ago

There are two reasons for the issue:

  1. Every worker needs to load the model first and then call init_process_group. If the model is big, loading takes a long time and the loading speed differs between workers, so the workers reach init_process_group at very different times, i.e., worker 1 has already started init_process_group while other workers are still loading the model. The client TCPStore timeout is currently set to the default value (300s), so if the time difference is long enough, the client may time out because the process group has not been established and the server TCPStore has not responded. We can make the client TCPStore timeout longer so that the process group can be created.

Related PR: https://github.com/intel-analytics/BigDL/pull/8152

  2. Even when the TCPStore timeout was long enough, we still sometimes got this error. I checked the rank info sent by each worker when calling init_process_group, and some workers had the same rank. Looking at get_rank(), its algorithm depends on the cluster info collected by the previous job. But the two jobs may have different partition placement (partitions on different executors, or a different number of partitions on the same executor), so ranks can be duplicated or some values can be missing. To work around this issue, we can set spark.task.cpus to the number of cores per executor, so that each executor runs exactly one worker RDD partition, which guarantees the same partition placement in both jobs. This issue only occurs on K8s.
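The first reason above (staggered model loading vs. a fixed rendezvous timeout) can be illustrated with a small, stdlib-only toy model. This is a sketch under stated assumptions, not BigDL or PyTorch code: `threading.Barrier` stands in for the TCPStore rendezvous inside init_process_group, its `timeout` stands in for the client TCPStore timeout, and a per-worker sleep stands in for model loading. The helper name `simulate_rendezvous` and all timings are hypothetical. In PyTorch itself the relevant knobs are the `timeout` arguments of `torch.distributed.init_process_group` and `torch.distributed.TCPStore`; how BigDL wires them is in the related PR and is not reproduced here.

```python
import threading
import time


def simulate_rendezvous(load_seconds, timeout_seconds):
    """Toy model of workers that load a model, then wait for each other.

    load_seconds[r] is a hypothetical model-loading time for worker r.
    A threading.Barrier stands in for the TCPStore rendezvous done inside
    init_process_group; its timeout stands in for the client store timeout.
    Returns one boolean per worker: True if that worker joined the group.
    """
    n = len(load_seconds)
    barrier = threading.Barrier(n, timeout=timeout_seconds)
    joined = [False] * n

    def worker(rank):
        time.sleep(load_seconds[rank])  # simulated model loading
        try:
            barrier.wait()  # block until all n workers arrive, or time out
            joined[rank] = True
        except threading.BrokenBarrierError:
            # Raised when a waiter times out (the barrier is then broken for
            # everyone), loosely analogous to "RuntimeError: Socket Timeout".
            joined[rank] = False

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return joined


if __name__ == "__main__":
    # Loading times differ, but every worker arrives well within the timeout.
    print(simulate_rendezvous([0.0, 0.05, 0.1], timeout_seconds=2.0))
    # The slow worker arrives after the fast one has already given up.
    print(simulate_rendezvous([0.0, 0.5], timeout_seconds=0.1))
```

In this toy model the first call returns all `True` and the second all `False`. Raising `timeout_seconds` above the spread between the fastest and slowest loading times lets the second case succeed, which is the idea behind lengthening the client TCPStore timeout.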
jenniew commented 1 year ago

For issue 2, we change the rank number to the partition id, so the rank on each worker is distinct and contiguous, and init_process_group can run successfully. We also no longer need to run the first job to collect cluster info.

Related PR: https://github.com/intel-analytics/BigDL/pull/8188
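The duplicate-rank failure and the partition-id fix can be modeled in a few lines. The scheme in `ranks_from_layout` is hypothetical and is not BigDL's actual get_rank() implementation; it only assumes, as described in the second reason above, that ranks are derived from the executor layout observed by an earlier job. Both helper names are illustrative. In real PySpark code the partition id would come from something like `TaskContext.get().partitionId()`; here it is simply each task's position in the list.

```python
from collections import Counter


def ranks_from_layout(prev_layout, curr_layout):
    """Hypothetical rank scheme (not BigDL's code).

    prev_layout / curr_layout list the executor that runs each partition, in
    partition order. Each executor's first rank is its offset in the layout
    observed by the previous job; a task then adds its local index among the
    tasks on the same executor in the current job. Assumes every executor in
    curr_layout also appeared in prev_layout.
    """
    counts = Counter(prev_layout)
    offsets = {}
    next_offset = 0
    for executor in sorted(counts):
        offsets[executor] = next_offset
        next_offset += counts[executor]

    seen = Counter()
    ranks = []
    for executor in curr_layout:
        ranks.append(offsets[executor] + seen[executor])
        seen[executor] += 1
    return ranks


def ranks_from_partition_id(curr_layout):
    """Rank = partition id: always distinct and contiguous (0..n-1)."""
    return list(range(len(curr_layout)))


if __name__ == "__main__":
    prev = ["exec-A", "exec-A", "exec-B", "exec-B"]
    # Same executors, but the second job places 3 partitions on exec-A.
    curr = ["exec-A", "exec-A", "exec-A", "exec-B"]
    print(ranks_from_layout(prev, curr))  # [0, 1, 2, 2]: 2 duplicated, 3 missing
    print(ranks_from_partition_id(curr))  # [0, 1, 2, 3]
```

When both jobs happen to place partitions identically, the layout-based scheme also yields `[0, 1, 2, 3]`, which is why forcing one partition per executor (the spark.task.cpus workaround) hides the problem, while the partition-id approach does not depend on placement at all.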

truongnx15 commented 1 year ago

I am having the same issue running PySpark Orca on a YARN cluster. The issue occurred in both the stable and nightly builds, and for both the Spark 2 and Spark 3 versions.

Some of my runs using a small number of nodes (<= 30) were successful; however, after scaling to 50 nodes or more, none has succeeded yet.

This is my code to init the Orca context:

    from bigdl.orca import init_orca_context

    init_orca_context(cluster_mode='yarn-client', num_nodes=100, cores=20, memory="16g",
                      driver_memory="20g", driver_cores=8,
                      extra_python_lib="model.py,config.zip,data.zip,loss.zip",
                      conf={"spark.task.cpus": "20",
                            "spark.dynamicAllocation.enabled": "false",
                            "spark.driver.maxResultSize": "4g"})

With a small number of nodes (<= 30), when runs succeeded, the setup_distributed function in pytorch_pyspark_worker finished in 1 or 2 seconds. When the issue happened, the function hung until the timeout (default 30 minutes).

@jenniew do you have any other thoughts about what the cause could be? If any info or logs are needed, I can help investigate this issue.

jenniew commented 1 year ago

@truongnx15 What is your data source for Orca Estimator training? Is it an RDD/DataFrame or a callable that creates a PyTorch data loader? Can you provide your error logs? We'll try to reproduce your issue to see how to fix it.

truongnx15 commented 1 year ago

We're using a DataFrame as the data source. My environment is Python 3.7 with the Orca for Spark 3 nightly build installed on 10 June 2023.

I ran with num_nodes=50 and 20 cores per node.

The error occurred 30 minutes after 50 log entries like the following (one per partition) were printed:

    [partition = 1, ip = 10.197.84.206] [2023-06-13 21:45:39] INFO cluster is: [LIST OF THE IP:PORT FOR THE CLUSTER]
    [partition = 1, ip = 10.197.84.206] [2023-06-13 21:45:39] INFO Connected log server on 10.197.74.131:34785

Below is the stack trace:

    Traceback (most recent call last):
      File "train_spark.py", line 178, in
        train()
      File "train_spark.py", line 160, in train
        epochs=int(config["model"]["num_epochs"])
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py", line 337, in fit
        lambda iter: transform_func(iter, init_params, params)).collect()
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/pyspark/rdd.py", line 949, in collect
        sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/py4j/java_gateway.py", line 1305, in call
        answer, self.gateway_client, self.target_id, self.name)
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/pyspark/sql/utils.py", line 111, in deco
        return f(*a, **kw)
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/py4j/protocol.py", line 328, in get_return_value
        format(target_id, ".", name), value)
    py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
    : org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(10, 40) finished unsuccessfully.
    org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/data7/yarn/nm/usercache/auser/appcache/application_1685365996012_140245/container_e08_1685365996012_140245_01_000039/pyspark.zip/pyspark/worker.py", line 604, in main
        process()
      File "/data7/yarn/nm/usercache/auser/appcache/application_1685365996012_140245/container_e08_1685365996012_140245_01_000039/pyspark.zip/pyspark/worker.py", line 594, in process
        out_iter = func(split_index, iterator)
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/pyspark/rdd.py", line 2916, in pipeline_func
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/pyspark/rdd.py", line 2863, in func
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py", line 337, in
      File "/data/home/auser/.conda/envs/graph_emb_spark3_nightly/lib/python3.7/site-packages/bigdl/orca/learn/pytorch/pytorch_pyspark_estimator.py", line 330, in transform_func
      File "/data7/yarn/nm/usercache/auser/appcache/application_1685365996012_140245/container_e08_1685365996012_140245_01_000039/python_env/lib/python3.7/site-packages/bigdl/orca/learn/pytorch/pytorch_pyspark_worker.py", line 93, in init
        self.setup_distributed(self.mode, cluster_info, driver_ip, driver_tcp_store_port)
      File "/data7/yarn/nm/usercache/auser/appcache/application_1685365996012_140245/container_e08_1685365996012_140245_01_000039/python_env/lib/python3.7/site-packages/bigdl/orca/learn/pytorch/pytorch_pyspark_worker.py", line 116, in setup_distributed
        world_size=self.size)
      File "/data7/yarn/nm/usercache/auser/appcache/application_1685365996012_140245/container_e08_1685365996012_140245_01_000039/python_env/lib/python3.7/site-packages/bigdl/orca/learn/pytorch/core/lifecycle.py", line 30, in setup_torch_distribute
        world_size)
      File "/data7/yarn/nm/usercache/auser/appcache/application_1685365996012_140245/container_e08_1685365996012_140245_01_000039/python_env/lib/python3.7/site-packages/bigdl/orca/learn/pytorch/core/lifecycle.py", line 77, in _init_torch_ddp
        world_size=world_size)
      File "/data7/yarn/nm/usercache/auser/appcache/application_1685365996012_140245/container_e08_1685365996012_140245_01_000039/python_env/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 769, in init_process_group
        timeout=timeout,
      File "/data7/yarn/nm/usercache/auser/appcache/application_1685365996012_140245/container_e08_1685365996012_140245_01_000039/python_env/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 862, in _new_process_group_helper
        pg = ProcessGroupGloo(prefix_store, group_rank, group_size, timeout=timeout)
    RuntimeError: Socket Timeout

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:517)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:652)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:635)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:470)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)
    at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
    at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
    at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
    at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
    at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
    at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
    at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
    at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
    at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
    at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
    at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
    at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
    at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1030)
    at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2236)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:131)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:498)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:501)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
jenniew commented 1 year ago

@truongnx15, thank you for the information. Would you mind attaching the whole log file? What is the approximate size of your training data?

truongnx15 commented 1 year ago

My DataFrame has about 5M rows and 5 columns (schema: 3 long, 1 float, and 1 array[long] with exactly 2 elements). Below is the full log file (from a different run on a different cluster using Python 3.8): spark3_50nodes.log

jenniew commented 1 year ago

@truongnx15 We created similar data and ran with 50-80 executors using your Spark configuration, but could not reproduce your issue.

truongnx15 commented 1 year ago

Thank you for looking into it. I am still having the same issue. It sometimes fails even with a small number of nodes, so I guess it's something related to the network or my YARN cluster setup, but I haven't figured it out yet.