pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Accumulated callbacks during backward pass can cause variation in computation graph. #4160

Open hjm-aws opened 1 year ago

hjm-aws commented 1 year ago

🐛 Bug

In the FSDP backward pass, if we accumulate callbacks and invoke them together later within one batch, different runs can produce slightly different computation graphs, which causes recompilation.
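Why a slightly different graph is costly can be shown with a toy compilation cache. This is a hypothetical sketch, not pytorch/xla's implementation (which hashes its IR graph in its own way); `graph_cache_key` and `CompileCache` are illustrative names. It only shows that when the cache key is order-sensitive, two graphs that differ merely in parameter order are separate cache entries, so each variant pays for its own compilation.

```python
import hashlib


def graph_cache_key(param_shapes, ops):
    """Hash a toy graph given its ordered parameter shapes and ops.

    Hypothetical: pytorch/xla computes its graph hash differently; this key
    is deliberately order-sensitive to illustrate the effect.
    """
    h = hashlib.sha256()
    for shape in param_shapes:
        h.update(repr(shape).encode())
    for op in ops:
        h.update(op.encode())
    return h.hexdigest()


class CompileCache:
    """Toy executable cache that counts how often it has to compile."""

    def __init__(self):
        self.executables = {}
        self.compilations = 0

    def get(self, param_shapes, ops):
        key = graph_cache_key(param_shapes, ops)
        if key not in self.executables:
            # Cache miss: in a real system this would be an expensive compile.
            self.compilations += 1
            self.executables[key] = f"executable-{self.compilations}"
        return self.executables[key]


cache = CompileCache()
ops = ["reduce-scatter", "add"]
cache.get([(1024,), (32, 1024)], ops)  # first step: compile
cache.get([(1024,), (32, 1024)], ops)  # identical graph: cache hit
cache.get([(32, 1024), (1024,)], ops)  # same ops, parameters reordered
print(cache.compilations)  # 2
```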

To Reproduce

Steps to reproduce the behavior:

  1. Patch https://github.com/pytorch/xla/pull/4145 and https://github.com/tensorflow/tensorflow/pull/58377. These PRs enable all-gather and reduce-scatter coalescence. The changes relevant to this issue are https://github.com/pytorch/xla/pull/4145/files#diff-b027868386d87ff1458d88491a4247e8f9330deddf505969e03d0e45a88199f3L1085-R1139 and https://github.com/pytorch/xla/pull/4145/files#diff-0dd71c2a8206ce45c21e0a8fe3cfd95fcd5620b8a8f2329578bb95f41da86943R82-R121.
  2. Run https://gist.github.com/hjm-aws/d3b402535db6729b30678eab15faafda, which is a small example that applies FSDP on HuggingFace BERT-Large, and will exercise the relevant code changes listed in step 1 above. The key piece in the example is reduce_scatter_bucket_size_mb=20, which will cause many callbacks to be accumulated and called together later.

Expected behavior

Observe the generated HLO dump. Example dumps are at https://github.com/hjm-aws/dump/blob/main/reduce_scatter_test-1101-16_49_12.hlo.zip and https://github.com/hjm-aws/dump/blob/main/reduce_scatter_test-1-layer-mlm.hlo. We should see two unique graphs with more than 35K instructions: a smaller one generated during the first fwd/bwd pass, and a bigger one generated during all subsequent fwd/bwd passes.

Actually, we see 4 such graphs: two extra unique graphs are generated in the subsequent fwd/bwd runs. In the example dump above, see the following graphs: HloModule IrToHlo.38523, HloModule IrToHlo.39056, and HloModule IrToHlo.38547.

They differ only slightly, for example in the order of their parameters.

Note

If we set reduce_scatter_bucket_size_mb=0, so that each callback is called right after its reduce-scatter (i.e. no accumulation), the issue does not happen.
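The difference between bucket size 0 and a larger bucket can be modeled with a toy. `ToyReduceScatterBucket` and `run` are hypothetical and are not the FSDP code from the PRs above; the sketch assumes a bucket is flushed once its queued bytes reach the threshold, and that a final flush happens at the end of backward. It does not reproduce the recompilation; it only shows that with accumulation, callbacks run in batches at bucket boundaries instead of right after each reduce-scatter.

```python
class ToyReduceScatterBucket:
    """Hypothetical model of bucketed reduce-scatter with deferred callbacks.

    Gradients are queued until their total size reaches bucket_size_bytes;
    the queued reduce-scatters and their callbacks then run together.
    With bucket_size_bytes=0 every gradient is flushed immediately.
    """

    def __init__(self, bucket_size_bytes):
        self.bucket_size_bytes = bucket_size_bytes
        self.pending = []  # (name, nbytes, callback) not yet flushed
        self.pending_bytes = 0
        self.events = []  # log of what ran, in order

    def reduce_scatter(self, name, nbytes, callback):
        self.pending.append((name, nbytes, callback))
        self.pending_bytes += nbytes
        if self.pending_bytes >= self.bucket_size_bytes:
            self.flush()

    def flush(self):
        if not self.pending:
            return
        self.events.append(("reduce_scatter", tuple(n for n, _, _ in self.pending)))
        for name, _, callback in self.pending:
            callback(name)
        self.pending = []
        self.pending_bytes = 0


def run(bucket_size_bytes, grads=("w1", "w2", "w3"), nbytes=4):
    bucket = ToyReduceScatterBucket(bucket_size_bytes)
    for name in grads:
        bucket.reduce_scatter(name, nbytes, lambda n: bucket.events.append(("callback", n)))
    bucket.flush()  # end of backward: flush whatever is still queued
    return bucket.events


print(run(0))  # one reduce-scatter per gradient, callback right after each
print(run(8))  # w1 and w2 coalesced, their callbacks run together
```

With accumulation, which ops sit between two callback batches depends on the bucket size and the parameter sizes, which is where the note above says the behavior diverges.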

Environment

cc: @ronghanghu

JackCaoG commented 1 year ago

Can this issue only be reproduced with BERT-Large? The HLO is far too large to make sense of what it is trying to do.

I am comparing HloModule IrToHlo.38523 and HloModule IrToHlo.39056

I can see that parameters p827.13201 through p852.34920 have different shapes


and the outputs at the end of the two HLO modules are slightly different.

It seems to me that HloModule IrToHlo.39056 is doing some additional computation that does not exist in HloModule IrToHlo.38523.
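The parameter-shape comparison can be scripted instead of eyeballed. A sketch assuming each input is the text of one HLO module and that parameters are printed as `%name = <shape> parameter(N)`; tuple-shaped parameters (whose shape text contains spaces) are not handled. `parameter_shapes` and `shape_mismatches` are hypothetical helpers.

```python
import re

# "%p827.13201 = f32[32,1024]{1,0} parameter(827), metadata=..."
PARAM_RE = re.compile(r"^\s*%?([\w.\-]+)\s*=\s*(\S+)\s+parameter\((\d+)\)")


def parameter_shapes(module_text):
    """Map parameter number -> (instruction name, shape string)."""
    params = {}
    for line in module_text.splitlines():
        m = PARAM_RE.match(line)
        if m:
            params[int(m.group(3))] = (m.group(1), m.group(2))
    return params


def shape_mismatches(text_a, text_b):
    """List (parameter number, shape in A, shape in B) where the shapes differ.

    A shape of None means the parameter exists in only one module.
    """
    a, b = parameter_shapes(text_a), parameter_shapes(text_b)
    out = []
    for idx in sorted(set(a) | set(b)):
        shape_a = a.get(idx, (None, None))[1]
        shape_b = b.get(idx, (None, None))[1]
        if shape_a != shape_b:
            out.append((idx, shape_a, shape_b))
    return out
```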

JackCaoG commented 1 year ago

Looking a bit further at where p827.13201 is used in the computation, I found the following:

...
  %p827.13201 = f32[32,1024]{1,0} parameter(827), metadata={op_type="xla__device_data" op_name="xla__device_data"}
...
 %tuple.13225 = (f32[1024]{0}, f32[1024]{0}, f32[1024,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, /*index=5*/f32[1024]{0}, f32[512,1024]{1,0}, f32[32,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, /*index=10*/f32[1024,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, f32[1024]{0}, f32[512,1024]{1,0}, /*index=15*/f32[32,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, f32[1024,1024]{1,0}, f32[1024]{0}, /*index=20*/f32[1024]{0}, f32[1024]{0}, f32[512,1024]{1,0}, f32[32,1024]{1,0}, f32[1024]{0}, /*index=25*/f32[1024]{0}, f32[1024,1024]{1,0}, f32[]) tuple(f32[1024]{0} %p850.13224, f32[1024]{0} %p849.13223, f32[1024,1024]{1,0} %p848.13222, f32[1024]{0} %p847.13221, f32[1024]{0} %p846.13220, /*index=5*/f32[1024]{0} %p845.13219, f32[512,1024]{1,0} %p844.13218, f32[32,1024]{1,0} %p843.13217, f32[1024]{0} %p842.13216, f32[1024]{0} %p841.13215, /*index=10*/f32[1024,1024]{1,0} %p840.13214, f32[1024]{0} %p839.13213, f32[1024]{0} %p838.13212, f32[1024]{0} %p837.13211, f32[512,1024]{1,0} %p836.13210, /*index=15*/f32[32,1024]{1,0} %p835.13209, f32[1024]{0} %p834.13208, f32[1024]{0} %p833.13207, f32[1024,1024]{1,0} %p832.13206, f32[1024]{0} %p831.13205, /*index=20*/f32[1024]{0} %p830.13204, f32[1024]{0} %p829.13203, f32[512,1024]{1,0} %p828.13202, f32[32,1024]{1,0} %p827.13201, f32[1024]{0} %reshape.13200, /*index=25*/f32[1024]{0} %reshape.13175, f32[1024,1024]{1,0} %get-tuple-element.13143, f32[] %get-tuple-element.13010), metadata={op_type="xla__reduce_scatter" op_name="xla__reduce_scatter"}
...
%get-tuple-element.13249 = f32[32,1024]{1,0} get-tuple-element((f32[1024]{0}, f32[1024]{0}, f32[1024,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, /*index=5*/f32[1024]{0}, f32[512,1024]{1,0}, f32[32,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, /*index=10*/f32[1024,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, f32[1024]{0}, f32[512,1024]{1,0}, /*index=15*/f32[32,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, f32[1024,1024]{1,0}, f32[1024]{0}, /*index=20*/f32[1024]{0}, f32[1024]{0}, f32[512,1024]{1,0}, f32[32,1024]{1,0}, f32[1024]{0}, /*index=25*/f32[1024]{0}, f32[1024,1024]{1,0}, f32[]) %tuple.13225), index=23, metadata={op_type="xla__reduce_scatter" op_name="xla__reduce_scatter"}
....
  %reduce-scatter.13258 = (f32[32]{0}, f32[32]{0}, f32[32,1024]{1,0}, f32[32]{0}, f32[32]{0}, /*index=5*/f32[32]{0}, f32[16,1024]{1,0}, f32[1,1024]{1,0}, f32[32]{0}, f32[32]{0}, /*index=10*/f32[32,1024]{1,0}, f32[32]{0}, f32[32]{0}, f32[32]{0}, f32[16,1024]{1,0}, /*index=15*/f32[1,1024]{1,0}, f32[32]{0}, f32[32]{0}, f32[32,1024]{1,0}, f32[32]{0}, /*index=20*/f32[32]{0}, f32[32]{0}, f32[16,1024]{1,0}, f32[1,1024]{1,0}, f32[32]{0}, /*index=25*/f32[32]{0}, f32[32,1024]{1,0}, f32[]) reduce-scatter(f32[1024]{0} %get-tuple-element.13226, f32[1024]{0} %get-tuple-element.13227, f32[1024,1024]{1,0} %get-tuple-element.13228, f32[1024]{0} %get-tuple-element.13229, f32[1024]{0} %get-tuple-element.13230, /*index=5*/f32[1024]{0} %get-tuple-element.13231, f32[512,1024]{1,0} %get-tuple-element.13232, f32[32,1024]{1,0} %get-tuple-element.13233, f32[1024]{0} %get-tuple-element.13234, f32[1024]{0} %get-tuple-element.13235, /*index=10*/f32[1024,1024]{1,0} %get-tuple-element.13236, f32[1024]{0} %get-tuple-element.13237, f32[1024]{0} %get-tuple-element.13238, f32[1024]{0} %get-tuple-element.13239, f32[512,1024]{1,0} %get-tuple-element.13240, /*index=15*/f32[32,1024]{1,0} %get-tuple-element.13241, f32[1024]{0} %get-tuple-element.13242, f32[1024]{0} %get-tuple-element.13243, f32[1024,1024]{1,0} %get-tuple-element.13244, f32[1024]{0} %get-tuple-element.13245, /*index=20*/f32[1024]{0} %get-tuple-element.13246, f32[1024]{0} %get-tuple-element.13247, f32[512,1024]{1,0} %get-tuple-element.13248, f32[32,1024]{1,0} %get-tuple-element.13249, f32[1024]{0} %get-tuple-element.13250, /*index=25*/f32[1024]{0} %get-tuple-element.13251, f32[1024,1024]{1,0} %get-tuple-element.13252, f32[] %get-tuple-element.13253), replica_groups={}, dimensions={0}, to_apply=%AddComputation.13254, metadata={op_type="xla__reduce_scatter" op_name="xla__reduce_scatter"}

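The trick is to find the value's index in the tuple and follow where that element is used. Here p827.13201 is element 23 of %tuple.13225, it is read back by %get-tuple-element.13249 (index=23), and that in turn is operand 23 of %reduce-scatter.13258. Following the chain further, we might eventually find where the value is consumed, or find out that it is never used. This graph is too big for me to trace where it eventually ends up.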
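The tracing described above can be automated for one module at a time. A sketch assuming the instruction form seen in the excerpt (`%name = <shape> opcode(... %operand ...)`, one instruction per line); the helper names are hypothetical. It builds a reverse "who uses me" index once, then walks it breadth-first.

```python
import re
from collections import deque

# "%name = <shape> opcode(operands), attrs", optionally prefixed by ROOT.
INSTR_RE = re.compile(r"^\s*(?:ROOT\s+)?%([\w.\-]+)\s*=\s*(.*)$")
# "get-tuple-element(<tuple type> %tuple.N), index=K"
GTE_RE = re.compile(r"get-tuple-element\(.*%([\w.\-]+)\),\s*index=(\d+)")
# Any operand reference such as "%p827.13201".
REF_RE = re.compile(r"%([\w.\-]+)")


def parse_instructions(hlo_text):
    """Map instruction name -> text to the right of '=' (one module)."""
    instrs = {}
    for line in hlo_text.splitlines():
        m = INSTR_RE.match(line)
        if m:
            instrs[m.group(1)] = m.group(2)
    return instrs


def tuple_element_names(instrs, tuple_name, index):
    """Names of get-tuple-element instructions reading `index` of `tuple_name`."""
    out = []
    for name, rhs in instrs.items():
        m = GTE_RE.search(rhs)
        if m and m.group(1) == tuple_name and int(m.group(2)) == index:
            out.append(name)
    return out


def build_users(instrs):
    """Reverse index: instruction name -> names of instructions using it."""
    users = {name: [] for name in instrs}
    for name, rhs in instrs.items():
        for ref in set(REF_RE.findall(rhs)):
            if ref in users:
                users[ref].append(name)
    return users


def transitive_users(instrs, start):
    """Breadth-first list of every instruction that (indirectly) uses `start`."""
    users = build_users(instrs)
    seen, order, queue = {start}, [], deque([start])
    while queue:
        for user in users.get(queue.popleft(), []):
            if user not in seen:
                seen.add(user)
                order.append(user)
                queue.append(user)
    return order
```

For the module excerpted above, `tuple_element_names(instrs, "tuple.13225", 23)` should yield `get-tuple-element.13249`, and `transitive_users` on that name lists everything downstream of it; an empty result would mean the value is never used.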

hjm-aws commented 1 year ago

Sorry Jack for the large graph!

It can be reproduced on an even larger (5 billion parameter) model. I haven't tried a smaller model yet.

hjm-aws commented 1 year ago

I added a smaller repro here: https://github.com/hjm-aws/dump/blob/main/reduce_scatter_test-1-layer-mlm.hlo. It has about 3k lines per HLO module and was generated with a new revision of https://gist.github.com/hjm-aws/d3b402535db6729b30678eab15faafda. I found that embedded (nested) FSDP, i.e. an FSDP-wrapped module inside another FSDP-wrapped module, is a necessary condition to reproduce the problem.

JackCaoG commented 1 year ago

There are many HLO modules in the raw file above. I picked two that look somewhat similar, HloModule IrToHlo.3047 and HloModule IrToHlo.2971; HloModule IrToHlo.3047 is the longer one.

The first big chunk of the diff comes from a series of constant(9) instructions. I am not sure what happened here; by default pytorch/xla only has constants 0 and 1. @hjm-aws, did you override that behavior somewhere?

I checked that none of these constants are actually used in the computation. From the metadata, it seems they come from a sum:

%constant.2017 = s32[] constant(9), metadata={op_type="aten__sum" op_name="aten__sum"}

So maybe somewhere in the sum we created a torch.tensor(9, device=xla_device). There are 12 constant(9) instructions on the left side of the diff and 12 on the right side; they just appear in different places in the graph.
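Because the same instructions can appear in different positions, a line-by-line diff overstates the difference. An order-insensitive comparison is one way to check this: erase instruction names, parameter numbers and metadata, then compare the two modules as multisets of instructions. A sketch assuming each argument is the text of one module, instructions use the `%name = ...` form, and metadata contains no `}`; the helper names are hypothetical.

```python
import re
from collections import Counter

INSTR_RE = re.compile(r"^\s*(?:ROOT\s+)?%[\w.\-]+\s*=\s*(.*)$")
METADATA_RE = re.compile(r",\s*metadata=\{[^}]*\}")
NAME_RE = re.compile(r"%[\w.\-]+")
PARAM_RE = re.compile(r"parameter\(\d+\)")


def normalized_instructions(module_text):
    """Multiset of instructions with names, parameter numbers and metadata erased."""
    out = Counter()
    for line in module_text.splitlines():
        m = INSTR_RE.match(line)
        if not m:
            continue
        rhs = METADATA_RE.sub("", m.group(1))
        rhs = PARAM_RE.sub("parameter(_)", rhs)
        out[NAME_RE.sub("%_", rhs)] += 1
    return out


def diff_modules(text_a, text_b):
    """Return (instructions only in A, instructions only in B), with counts."""
    a, b = normalized_instructions(text_a), normalized_instructions(text_b)
    return a - b, b - a
```

Two empty results mean the modules contain the same instructions and differ only in ordering and numbering; whatever remains is the genuinely extra computation.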

JackCaoG commented 1 year ago

I took a look at the remaining HLO, and most of it seems to be just tupling and untupling values.

I have a theory. mark_step works by looking at all live tensors that need to be synced and constructing an IR graph from them (via a post-order traversal). What might be happening here is that uncleared tensors accumulate across multiple steps, and that set of tensors is slightly different on each run. A common cause is additional work done after a step, such as logging, which pollutes the graph.
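The theory can be illustrated with a toy model. `Node` and `post_order` are hypothetical and are not pytorch/xla's IR classes; they only mimic the behavior described above, where the graph for a step is the post-order traversal of everything reachable from the live tensors. One extra live tensor, such as one kept around for logging, is enough to change the graph.

```python
class Node:
    """Hypothetical IR node: an op name plus its input nodes."""

    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = inputs


def post_order(roots):
    """Iterative post-order traversal from the live roots; shared nodes visited once."""
    seen, order = set(), []
    for root in roots:
        stack = [(root, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                order.append(node.op)  # all inputs already emitted
            elif id(node) not in seen:
                seen.add(id(node))
                stack.append((node, True))
                # Push inputs in reverse so the first input is visited first.
                for inp in reversed(node.inputs):
                    stack.append((inp, False))
    return order


x = Node("param:x")
loss = Node("mul", Node("add", x, x), x)

# Step 1: only the loss is live when the step is cut.
step1 = post_order([loss])
# Step 2: user code kept a tensor derived from loss alive (e.g. for logging).
leaked = Node("mean", loss)
step2 = post_order([loss, leaked])

print(step1)  # ['param:x', 'add', 'mul']
print(step2)  # ['param:x', 'add', 'mul', 'mean']
```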

One additional debugging step is to dump all live tensors before each mark_step. I expect we will see a different set of live tensors each time, which would explain the HLO differences. You can do that with the following patch:

diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 32e2fe72..b4584c21 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -942,6 +942,8 @@ def _run_step_closures():

 def mark_step():
+  print(torch_xla._XLAC._xla_tensors_report(0, 'xla:0'))
   if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
     print(
         'torch_xla.core.xla_model::mark_step\n',

I would recommend setting a high threshold such as 500 (the diff above uses 0, which prints every live tensor), so that only tensors with more than 500 IR nodes attached are printed. This API prints the pending IR attached to every live tensor, so the output will be very long.
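As an alternative to editing xla_model.py, the report can be hooked in from the training script and diffed against the previous step, so only the tensors that changed between steps stand out. `with_pre_step_report` is a hypothetical helper. The commented lines assume `torch_xla._XLAC._xla_tensors_report(threshold, device)` returns the report as a string, as the `print` in the diff suggests; reassigning `xm.mark_step` only affects callers that look it up through the `xm` module at call time.

```python
import difflib
import functools


def with_pre_step_report(mark_step, report, sink=print):
    """Wrap `mark_step` so a live-tensor report is emitted before every step.

    The first report is emitted in full; later ones as a unified diff
    against the previous step's report.
    """
    previous = None

    @functools.wraps(mark_step)
    def wrapped(*args, **kwargs):
        nonlocal previous
        current = report().splitlines()
        if previous is None:
            sink("\n".join(current))
        else:
            diff = "\n".join(difflib.unified_diff(
                previous, current, "previous step", "this step", lineterm=""))
            sink(diff or "(report unchanged)")
        previous = current
        return mark_step(*args, **kwargs)

    return wrapped


# Hypothetical usage in the training script (requires torch_xla):
# import torch_xla
# import torch_xla.core.xla_model as xm
# xm.mark_step = with_pre_step_report(
#     xm.mark_step,
#     lambda: torch_xla._XLAC._xla_tensors_report(500, 'xla:0'),
# )
```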