sanjayss34 opened this issue 3 months ago
Thanks for reaching out! Judging from your graph dumps, it is graph2 that takes 8 minutes to run, since it has ~47k IRs while graph1 only has ~11k IRs. I think there are a couple of things we can do here:
Let's take a profile so we can see which operations are the bottleneck. 8 minutes is really long even for a v3; the profile will confirm whether the execution really takes that long.
Please take a look at the example in https://github.com/pytorch/xla/blob/master/examples/debug/train_resnet_profile.py for how to take a profile. In your case you might want to set duration_ms long enough to capture graph2. You can also check out my videos at https://youtu.be/40jYVhQHGEA?si=IvDt6kAa39cMLQJC and https://youtu.be/LK3A3vjo-KQ?si=a96c5XZqmfvI34FW for basic debugging tips.
If we check the profile and still can't figure out what's wrong, we will need to dump the HLO and ask the XLA:TPU team to help. What you currently shared is the IR. You can dump HLOs with XLA_FLAGS=--xla_dump_to=/tmp/xla_dump, which will write the HLO at different optimization stages.
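A minimal sketch of how to apply that flag, assuming it is set before torch_xla is imported so XLA picks it up at initialization (the script name and dump path are placeholders):

```python
# Equivalent shell form: XLA_FLAGS=--xla_dump_to=/tmp/xla_dump python train.py
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"  # set before XLA initializes

import torch
import torch_xla.core.xla_model as xm

# ...build the model and run training as usual; one HLO module per compiled
# graph, at the different optimization stages, will be written to /tmp/xla_dump.
```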
cc @wonjoolee95 since you are on-call this week.
Hi Jack, thanks for your quick reply! Here's a screenshot of the profile I'm getting. I'm not sure why no steps are being recorded; do you have any suggestions? I added xp.start_server() to my training code and then used capture_profile.py as described here (https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) to get the profile.
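For reference, the profiler part of my script looks roughly like this (the port is a placeholder):

```python
import torch_xla.debug.profiler as xp

def main():
    # Start the profiler gRPC server so capture_profile.py can attach to it;
    # keep the handle alive for the whole run.
    server = xp.start_server(9012)
    train()  # regular training loop

if __name__ == "__main__":
    main()
```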
Maybe try the programmatic capture to get the profile? You can start the capture at step 2, for example (by that time the compilation should already be cached). One problem with capture_profile.py is that it is hard to capture the actual device execution, because at the moment you attach, the model might be doing compilation, data loading, or something else. Inserting trace_detached right before the model step function, after the initial steps, will guarantee that you capture the actual device execution, which is what we need.

https://github.com/pytorch/xla/blob/44f88a9d6135abe5cbb533485b40e19d11b88b23/examples/debug/train_resnet_profile.py#L24-L25 is a simple example; I recommend trying it, and by running that example you should be able to see some trace. You can also check out my videos above.
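A minimal sketch of the programmatic capture (the port, step number, duration, log directory, and the names loader / train_step are placeholders, not values from the linked example):

```python
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # must match the address passed below

for step, batch in enumerate(loader):
    if step == 5:
        # Start a background capture once the graphs are compiled and cached, so
        # the trace shows real device execution rather than compilation; make
        # duration_ms long enough to cover a few full training steps.
        xp.trace_detached("localhost:9012", "/tmp/profile", duration_ms=30000)
    train_step(batch)  # the regular step function (fwd / bwd / optimizer)
```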
I ran the "capture_profile.py" script after a few training steps had already happened. I did try the programmatic approach but got some errors. I can revisit it if really necessary, but do you have any other idea of why no training steps are being recorded by the "capture_profile.py" script?
In the tools dropdown, if you click into the trace_viewer, do you see anything? If you still don't see anything, can you run https://github.com/pytorch/xla/blob/44f88a9d6135abe5cbb533485b40e19d11b88b23/examples/debug/train_resnet_profile.py#L24-L25 and check whether you see anything in that profile? I'm trying to figure out whether this is a TensorBoard version issue or a capturing issue. One easy way to tell is to check the size of the xplane file you captured; it is usually at least tens of MB when it contains useful info.
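For example, a quick size check could look like this (the log directory is a placeholder for wherever the capture was written):

```python
# An xplane file that actually contains device traces is usually at least tens
# of MB; a few hundred KB usually means nothing useful was captured.
from pathlib import Path

logdir = Path("/tmp/profile")
for f in logdir.rglob("*.xplane.pb"):
    print(f, f"{f.stat().st_size / 1e6:.1f} MB")
```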
Thanks again for your reply! I got some potentially useful output after waiting longer and trying capture_profile.py once more:
This is what the trace_viewer looks like:

Also, the file size of *.xplane.pb is 82 MB.
The device is not being utilized very well; it seems to be idle most of the time, and I am wondering why. In https://www.youtube.com/watch?v=40jYVhQHGEA I talk about how to understand a profile. If you can upload your xplane file somewhere I can download it, I can also take a quick look.
Thanks, here's the link to the xplane file: https://drive.google.com/file/d/1m73dIa4NO4IatB-v-LoMuUECUj2hEXIM/view?usp=sharing
It seems like you are doing multi-process training on v3 (not using SPMD). There are a couple of different problems I am seeing in the profile:

- SyncTensorsGraph.107141(2418538244349149121): all it does is xla__cross_replica_sum, which is pretty much just an all_reduce.
- SyncTensorsGraph.18860(15603542360290435973): it seems to be just an adamw optimizer step.

It seems like you either have multiple mark_step calls in a single step function, or you accessed some data that triggered an execution (you can verify this with PT_XLA_DEBUG=1; see the sketch at the end of this comment). At the very least, I was not able to see the actual model fwd and bwd in this profile; increasing the profiling duration may help.

I would suggest reducing the model to run on a single v3-8 and profiling again; this will make it easier to see other performance problems. I also recommend just running the example above to see what an expected profile looks like and starting from there.
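A minimal sketch of a step function with a single graph execution per step, run with PT_XLA_DEBUG=1 to surface any extra executions (the loop structure and the names model / loader / optimizer / loss_fn are assumptions for illustration, not taken from your code):

```python
# Launch with: PT_XLA_DEBUG=1 python train.py
# PT_XLA_DEBUG=1 prints a message whenever an execution is triggered, which
# makes stray mark_step calls or tensor accesses easy to spot.
import torch_xla.core.xla_model as xm

def train_one_epoch(model, loader, optimizer, loss_fn):
    for data, target in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        xm.optimizer_step(optimizer)  # all_reduce of grads + optimizer update
        xm.mark_step()                # the only graph cut in the step: fwd, bwd,
                                      # all_reduce, and adamw compile into one graph
        # Avoid loss.item() / printing tensors inside the loop; that forces an
        # extra, separate execution like the small graphs in this profile.
```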
🐛 Bug
My training run with PyTorch XLA is running very slowly (on TPU v3-64), and as suggested by the debug messages, I am submitting dumps of the execution graphs and soliciting feedback on how I can optimize the speed. Thanks for your support!
To Reproduce
Please see the attached graph dumps.
Expected behavior
Here are excerpts from the output in debug mode:
and
Environment
Additional context
Attaching graph1.txt and graph2.txt, which contain the graph dumps.