pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Graph dump to optimize #7693

Open sanjayss34 opened 1 month ago

sanjayss34 commented 1 month ago

🐛 Bug

My training run with PyTorch XLA is running very slowly (on TPU v3-64), and as suggested by the debug messages, I am submitting dumps of the execution graphs and soliciting feedback on how I can optimize the speed. Thanks for your support!

To Reproduce

Please see the attached graph dumps.

Expected behavior

Here are excerpts from the output in debug mode:

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   mark_step in parallel loader at step end
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 38e7631f2fac7433efcbec5f61d32ce3
Execution Analysis:   Number of Graph Inputs: 903
Execution Analysis:   Number of Graph Outputs: 1344
Execution Analysis: Python Frame Triggered Execution:  
Execution Analysis:   mark_step (/home/sanjays1/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Execution Analysis:   next (/home/sanjays1/.local/lib/python3.10/site-packages/torch_xla/distributed/parallel_loader.py:44)
Execution Analysis:   __next__ (/home/sanjays1/.local/lib/python3.10/site-packages/torch_xla/distributed/parallel_loader.py:32)
Execution Analysis:   run_training (/home/sanjays1/prismatic-video-lms/prismatic/training/strategies/base_strategy.py:271)
Execution Analysis:   pretrain (/home/sanjays1/prismatic-video-lms/scripts/pretrain.py:295)
Execution Analysis:   wrapper_inner (/home/sanjays1/.local/lib/python3.10/site-packages/draccus/argparsing.py:203)
Execution Analysis:   xla_train (/home/sanjays1/prismatic-video-lms/scripts/xla_pretrain.py:29)
Execution Analysis:   __call__ (/home/sanjays1/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:187)
Execution Analysis:   ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
pt-xla-profiler: ExecuteTime too slow: longest instance took 08m37s779ms801.697us. Please open a GitHub issue with the graph dump for our team to optimize.
pt-xla-profiler: ExecuteTime too slow: longest instance took 08m37s779ms801.697us. Please open a GitHub issue with the graph dump for our team to optimize.
pt-xla-profiler: ExecuteTime too slow: longest instance took 06m50s555ms470.186us. Please open a GitHub issue with the graph dump for our team to optimize.

and

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   user mark_step
Execution Analysis: Graph Info:
Execution Analysis:   Graph Hash: 4852f46ddb5bf690e918e93631fcbeda
Execution Analysis:   Number of Graph Inputs: 1168
Execution Analysis:   Number of Graph Outputs: 1540
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis:   mark_step (/home/sanjays1/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Execution Analysis:   mark_step (/home/sanjays1/prismatic-video-lms/prismatic/overwatch/overwatch.py:207)
Execution Analysis:   run_training (/home/sanjays1/prismatic-video-lms/prismatic/training/strategies/base_strategy.py:313)
Execution Analysis:   pretrain (/home/sanjays1/prismatic-video-lms/scripts/pretrain.py:295)
Execution Analysis:   wrapper_inner (/home/sanjays1/.local/lib/python3.10/site-packages/draccus/argparsing.py:203)
Execution Analysis:   xla_train (/home/sanjays1/prismatic-video-lms/scripts/xla_pretrain.py:29)
Execution Analysis:   __call__ (/home/sanjays1/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:187)
Execution Analysis:   _thread_fn (/home/sanjays1/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:71)
Execution Analysis:   ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
pt-xla-profiler: ExecuteTime too slow: longest instance took 08m37s779ms801.697us. Please open a GitHub issue with the graph dump for our team to optimize.
pt-xla-profiler: ExecuteTime too slow: longest instance took 06m50s555ms470.186us. Please open a GitHub issue with the graph dump for our team to optimize.
pt-xla-profiler: ExecuteTime too slow: longest instance took 06m50s555ms470.186us. Please open a GitHub issue with the graph dump for our team to optimize.

Environment

Additional context

Attaching graph1.txt and graph2.txt which have the graph dumps. graph1.txt graph2.txt

JackCaoG commented 1 month ago

Thanks for reaching out! Judging from your graph dumps, it is graph2 that takes 8 minutes to run, since it has ~47k IRs while graph1 only has 11k IRs. I think there are a couple of things we can do here:
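A rough way to check these IR counts against the attached graph1.txt and graph2.txt is to count node definitions. This sketch assumes each IR node is printed on its own line in the form `%<id> = ...`, which is how PyTorch/XLA's textual IR dumps look as far as we know; `count_ir_nodes` and `count_ir_nodes_in_file` are hypothetical helpers, not part of torch_xla.

```python
import re
from pathlib import Path

# Assumption: each IR node definition contains "%<number> =" exactly once on its own line.
_NODE_DEF = re.compile(r"%\d+\s*=")


def count_ir_nodes(text: str) -> int:
    """Count the lines of a textual IR dump that define a node."""
    return sum(1 for line in text.splitlines() if _NODE_DEF.search(line))


def count_ir_nodes_in_file(path: str) -> int:
    """Convenience wrapper, e.g. count_ir_nodes_in_file("graph2.txt")."""
    return count_ir_nodes(Path(path).read_text())
```

Operand references such as `aten::add(%1, %0)` are not followed by `=`, so they are not counted as extra nodes.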

  1. Let's take a profile so we can see which operations are the bottleneck. 8 minutes is really long even for a v3, and the profile will confirm whether the execution really takes that long. Please take a look at the example in https://github.com/pytorch/xla/blob/master/examples/debug/train_resnet_profile.py for how to take a profile. In your case you might want to set duration_ms long enough to capture graph2. You can also check out my videos at https://youtu.be/40jYVhQHGEA?si=IvDt6kAa39cMLQJC and https://youtu.be/LK3A3vjo-KQ?si=a96c5XZqmfvI34FW for basic debugging tips.

  2. If we check the profile and still can't figure out what's wrong, we need to dump the HLO and ask the XLA:TPU team for help. What you have shared so far is the IR. You can dump HLOs using XLA_FLAGS=--xla_dump_to=/tmp/xla_dump, which will dump HLOs at different optimization stages.
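Because XLA_FLAGS is read from the environment, it has to be in place before torch_xla initializes its runtime, for example in the environment of the launch command. A minimal sketch that adds `--xla_dump_to` while keeping any other XLA flags already set; `with_hlo_dump` is a hypothetical helper, and only the `--xla_dump_to` flag itself comes from the comment above.

```python
import os


def with_hlo_dump(dump_dir: str = "/tmp/xla_dump", base_env=None) -> dict:
    """Return a copy of the environment with --xla_dump_to set, preserving other XLA_FLAGS."""
    env = dict(os.environ if base_env is None else base_env)
    # Drop any previous dump target so the new one is the only --xla_dump_to flag.
    flags = [f for f in env.get("XLA_FLAGS", "").split() if not f.startswith("--xla_dump_to=")]
    flags.append(f"--xla_dump_to={dump_dir}")
    env["XLA_FLAGS"] = " ".join(flags)
    return env
```

One way to use it is to pass `env=with_hlo_dump()` to `subprocess.run` together with whatever command you normally use to launch training.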

cc @wonjoolee95 since you are on call this week.

sanjayss34 commented 1 month ago

Hi Jack, thanks for your quick reply! Here's a screenshot of the profile I'm getting. I'm not sure why no steps are being recorded; do you have any suggestions? I added xp.start_server() to my training code and then used capture_profile.py as described here (https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) to get the profile.

[screenshot of the profile]
JackCaoG commented 1 month ago

Maybe try the programmatic_capture approach to capture the profile? You can start the capture at step 2, for example (by that time the compilation should already be cached). One problem with capture_profile.py is that it is hard to capture the actual device execution, because the model might be doing compilation, data loading, or something else at that moment. Inserting trace_detached right before the model step function, after the initial steps, guarantees that you capture the actual device execution, which is what we need.
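A minimal sketch of this idea, assuming `xp` is `torch_xla.debug.profiler` (whose `start_server` and `trace_detached` the thread refers to). It is not copied from the train_resnet_profile.py example linked in this thread. The port, log directory, the 1.5x margin, and all helper names (`parse_profiler_duration_ms`, `capture_window_ms`, `should_start_trace`, `train_with_profile`) are hypothetical. The duration helpers turn the "longest instance took 08m37s779ms801.697us" message from the report into a `duration_ms` long enough to cover one full graph2 execution, as suggested earlier.

```python
import math
import re

PROFILER_PORT = 9012                 # assumption: any free local port
PROFILE_LOGDIR = "/tmp/xla_profile"  # hypothetical output directory

# Matches the duration format printed by pt-xla-profiler, e.g. "08m37s779ms801.697us".
_DURATION = re.compile(r"(\d+)m(\d+)s(\d+)ms([\d.]+)us")


def parse_profiler_duration_ms(text: str) -> float:
    """Convert a pt-xla-profiler duration string into milliseconds."""
    match = _DURATION.search(text)
    if match is None:
        raise ValueError(f"no duration found in {text!r}")
    minutes, seconds, millis, micros = match.groups()
    return int(minutes) * 60_000 + int(seconds) * 1_000 + int(millis) + float(micros) / 1_000


def capture_window_ms(longest_ms: float, margin: float = 1.5) -> int:
    """Trace long enough to cover one full slow execution, with some headroom."""
    return math.ceil(longest_ms * margin)


def should_start_trace(step: int, start_step: int) -> bool:
    """Trigger the capture exactly once, after the warm-up/compilation steps."""
    return step == start_step


def train_with_profile(train_step, loader, start_step: int = 2, duration_ms: int = 600_000):
    """`train_step` and `loader` are placeholders for the user's own step function and loader."""
    import torch_xla.debug.profiler as xp  # lazy import: the helpers above need no torch_xla

    server = xp.start_server(PROFILER_PORT)  # keep a reference so the server stays alive
    for step, batch in enumerate(loader):
        if should_start_trace(step, start_step):
            # Capture on a background thread while training keeps running.
            xp.trace_detached(f"localhost:{PROFILER_PORT}", PROFILE_LOGDIR, duration_ms=duration_ms)
        train_step(batch)
    del server
```

For example, `capture_window_ms(parse_profiler_duration_ms("08m37s779ms801.697us"))` gives a window of roughly 13 minutes. In a multi-process run you would probably want only one process to issue the trace.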

https://github.com/pytorch/xla/blob/44f88a9d6135abe5cbb533485b40e19d11b88b23/examples/debug/train_resnet_profile.py#L24-L25 is a simple example, and I recommend trying it: running this example, you should be able to see some traces. You can also check out my video above.

sanjayss34 commented 1 month ago

I ran the “capture_profile.py” script after a few training steps had already happened. I did try the programmatic approach but got some errors. I can revisit it if really necessary, but do you have any other idea why no training steps are being recorded by the “capture_profile.py” script?

JackCaoG commented 1 month ago

In the tools, if you click into the trace_viewer, do you see anything? If you still don't see anything, can you see anything in the profile when running https://github.com/pytorch/xla/blob/44f88a9d6135abe5cbb533485b40e19d11b88b23/examples/debug/train_resnet_profile.py#L24-L25? I'm trying to tell whether this is a TensorBoard version issue or a capturing issue. One easy way to tell is to check the size of the xplane file you captured; it is usually at least xx MB when it captures useful information.
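A quick way to check the captured file sizes, assuming the trace was written somewhere under the log directory passed to the capture. The exact subdirectory layout may differ between versions, so the sketch searches recursively; `xplane_sizes_mb` is a hypothetical helper.

```python
from pathlib import Path


def xplane_sizes_mb(logdir: str) -> dict:
    """Map each *.xplane.pb file under `logdir` to its size in MB (1 MB = 2**20 bytes)."""
    return {
        str(path): path.stat().st_size / 2**20
        for path in sorted(Path(logdir).rglob("*.xplane.pb"))
    }
```

A file of only a few KB usually means the capture window missed the device execution.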

sanjayss34 commented 1 month ago

Thanks again for your reply! I got some potentially useful output after waiting longer and trying capture_profile.py once more:

[screenshot of the profile output]

This is what the trace_viewer looks like:

[screenshot of the trace_viewer]

Also, the file size of *.xplane.pb is 82MB.

JackCaoG commented 1 month ago

The device is not being utilized very well; it seems to be idle most of the time, and I am wondering why. In https://www.youtube.com/watch?v=40jYVhQHGEA I talk about how to understand a profile. If you can upload your xplane file somewhere I can download it, I can also take a quick look.
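One way to quantify "idle most of the time" is to read the start and end times of the device executions off the trace_viewer and compute the busy fraction and the longest idle gap. This is a minimal pure-Python sketch. It does not parse the xplane file, the interval values in the example are made up, and `device_utilization` is a hypothetical helper.

```python
def device_utilization(intervals, window_start, window_end):
    """Return (busy fraction, longest idle gap) for device executions inside a time window.

    `intervals` is a list of (start, end) execution times in any consistent unit.
    Overlapping intervals are merged so shared time is not counted twice.
    """
    if window_end <= window_start:
        raise ValueError("window_end must be greater than window_start")
    clipped = sorted(
        (max(s, window_start), min(e, window_end))
        for s, e in intervals
        if e > window_start and s < window_end
    )
    busy = 0.0
    longest_gap = 0.0
    cursor = window_start  # end of the region covered so far
    for start, end in clipped:
        if start > cursor:
            longest_gap = max(longest_gap, start - cursor)
        if end > cursor:
            busy += end - max(start, cursor)
            cursor = end
    longest_gap = max(longest_gap, window_end - cursor)  # trailing idle time
    return busy / (window_end - window_start), longest_gap
```

For example, executions at (0, 2), (1, 3), and (8, 10) in a window from 0 to 10 give a busy fraction of 0.5 and a longest gap of 5 time units.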

sanjayss34 commented 1 month ago

Thanks, here's the link to the xplane file: https://drive.google.com/file/d/1m73dIa4NO4IatB-v-LoMuUECUj2hEXIM/view?usp=sharing

JackCaoG commented 1 month ago

It seems like you are doing multi-process training on v3 (not using SPMD). I see a couple of different problems in the profile:

  1. The gap between executions is really big. It is hard for me to tell what the host is doing during the idle time.
  2. In the first execution, SyncTensorsGraph.107141(2418538244349149121), all it does is xla__cross_replica_sum, which is pretty much just an all_reduce.
  3. In the second graph, SyncTensorsGraph.18860(15603542360290435973), it seems like it is just an AdamW optimizer step.
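The two log excerpts in the original report point the same way. One execution is caused by "mark_step in parallel loader at step end", and another by "user mark_step", coming from overwatch.py:207. A sketch that tallies the Execution Cause entries in a saved debug log, assuming the "Execution Analysis:" line layout shown in the report. If both causes appear about once per training step, each step is probably running at least two graphs. `execution_causes` is a hypothetical helper.

```python
from collections import Counter


def execution_causes(log_text: str) -> Counter:
    """Count the cause line that follows each 'Execution Cause' header in the debug output."""
    causes = Counter()
    lines = log_text.splitlines()
    for i, line in enumerate(lines):
        if line.strip().endswith("Execution Cause") and i + 1 < len(lines):
            # Keep only the text after the "Execution Analysis:" prefix.
            cause = lines[i + 1].split("Execution Analysis:", 1)[-1].strip()
            causes[cause] += 1
    return causes
```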

It seems like you either have multiple mark_steps in a single step function, or you access some data that triggers an execution (you can verify this with PT_XLA_DEBUG=1). At the very least, I was not able to see the actual model forward and backward passes in this profile; maybe increasing the profiling duration will help.
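As a sketch of a step loop that aims for one graph execution per step, the loop can rely on `MpDeviceLoader` (from `torch_xla.distributed.parallel_loader`) to call mark_step at the end of each step. It uses `xm.optimizer_step` so the gradient all-reduce and the optimizer update are traced together, and `xm.add_step_closure` to defer `loss.item()` until after the step's graph has run. This is a hedged illustration, not the reporter's code. `train_epoch`, `should_log`, and `log_every` are hypothetical names, and `model`, `loader`, `optimizer`, `loss_fn` stand in for the user's own objects.

```python
def should_log(step: int, log_every: int) -> bool:
    """Log only every `log_every` steps; a non-positive value disables logging."""
    return log_every > 0 and step % log_every == 0


def train_epoch(model, loader, optimizer, loss_fn, device, log_every: int = 50):
    """Hypothetical step loop; `loader` is assumed to yield (inputs, targets) on the host."""
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    # MpDeviceLoader calls mark_step at the end of every step, so no extra
    # xm.mark_step() is added inside the loop.
    device_loader = pl.MpDeviceLoader(loader, device)
    for step, (inputs, targets) in enumerate(device_loader):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        # All-reduces gradients across replicas, then calls optimizer.step();
        # with no mark_step in between, both are traced into the same pending graph.
        xm.optimizer_step(optimizer)
        if should_log(step, log_every):
            # Runs after the step's graph executes instead of forcing an early execution.
            xm.add_step_closure(
                lambda s, l: print(f"step {s}: loss {l.item():.4f}"), args=(step, loss)
            )
```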

I would suggest reducing the model to run on a single v3-8 and profiling again; this will make it easier to spot other performance problems. I also recommend just running the example above to see what an expected profile looks like and starting from there.