Open hjm-aws opened 1 year ago
Can this issue only be reproduced via bert-large? The HLO is way too large to make sense of what it is trying to do.
I am comparing HloModule IrToHlo.38523 and HloModule IrToHlo.39056. I can see that the parameters from p827.13201 to p852.34920 have different shapes, and the output is slightly different between the two HLOs at the end. It seems to me that HloModule IrToHlo.39056 is doing some additional computation that does not exist in HloModule IrToHlo.38523. If I look a bit further at where p827.13201 is being used in the computation, I found that
...
%p827.13201 = f32[32,1024]{1,0} parameter(827), metadata={op_type="xla__device_data" op_name="xla__device_data"}
...
%tuple.13225 = (f32[1024]{0}, f32[1024]{0}, f32[1024,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, /*index=5*/f32[1024]{0}, f32[512,1024]{1,0}, f32[32,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, /*index=10*/f32[1024,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, f32[1024]{0}, f32[512,1024]{1,0}, /*index=15*/f32[32,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, f32[1024,1024]{1,0}, f32[1024]{0}, /*index=20*/f32[1024]{0}, f32[1024]{0}, f32[512,1024]{1,0}, f32[32,1024]{1,0}, f32[1024]{0}, /*index=25*/f32[1024]{0}, f32[1024,1024]{1,0}, f32[]) tuple(f32[1024]{0} %p850.13224, f32[1024]{0} %p849.13223, f32[1024,1024]{1,0} %p848.13222, f32[1024]{0} %p847.13221, f32[1024]{0} %p846.13220, /*index=5*/f32[1024]{0} %p845.13219, f32[512,1024]{1,0} %p844.13218, f32[32,1024]{1,0} %p843.13217, f32[1024]{0} %p842.13216, f32[1024]{0} %p841.13215, /*index=10*/f32[1024,1024]{1,0} %p840.13214, f32[1024]{0} %p839.13213, f32[1024]{0} %p838.13212, f32[1024]{0} %p837.13211, f32[512,1024]{1,0} %p836.13210, /*index=15*/f32[32,1024]{1,0} %p835.13209, f32[1024]{0} %p834.13208, f32[1024]{0} %p833.13207, f32[1024,1024]{1,0} %p832.13206, f32[1024]{0} %p831.13205, /*index=20*/f32[1024]{0} %p830.13204, f32[1024]{0} %p829.13203, f32[512,1024]{1,0} %p828.13202, f32[32,1024]{1,0} %p827.13201, f32[1024]{0} %reshape.13200, /*index=25*/f32[1024]{0} %reshape.13175, f32[1024,1024]{1,0} %get-tuple-element.13143, f32[] %get-tuple-element.13010), metadata={op_type="xla__reduce_scatter" op_name="xla__reduce_scatter"}
...
%get-tuple-element.13249 = f32[32,1024]{1,0} get-tuple-element((f32[1024]{0}, f32[1024]{0}, f32[1024,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, /*index=5*/f32[1024]{0}, f32[512,1024]{1,0}, f32[32,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, /*index=10*/f32[1024,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, f32[1024]{0}, f32[512,1024]{1,0}, /*index=15*/f32[32,1024]{1,0}, f32[1024]{0}, f32[1024]{0}, f32[1024,1024]{1,0}, f32[1024]{0}, /*index=20*/f32[1024]{0}, f32[1024]{0}, f32[512,1024]{1,0}, f32[32,1024]{1,0}, f32[1024]{0}, /*index=25*/f32[1024]{0}, f32[1024,1024]{1,0}, f32[]) %tuple.13225), index=23, metadata={op_type="xla__reduce_scatter" op_name="xla__reduce_scatter"}
....
%reduce-scatter.13258 = (f32[32]{0}, f32[32]{0}, f32[32,1024]{1,0}, f32[32]{0}, f32[32]{0}, /*index=5*/f32[32]{0}, f32[16,1024]{1,0}, f32[1,1024]{1,0}, f32[32]{0}, f32[32]{0}, /*index=10*/f32[32,1024]{1,0}, f32[32]{0}, f32[32]{0}, f32[32]{0}, f32[16,1024]{1,0}, /*index=15*/f32[1,1024]{1,0}, f32[32]{0}, f32[32]{0}, f32[32,1024]{1,0}, f32[32]{0}, /*index=20*/f32[32]{0}, f32[32]{0}, f32[16,1024]{1,0}, f32[1,1024]{1,0}, f32[32]{0}, /*index=25*/f32[32]{0}, f32[32,1024]{1,0}, f32[]) reduce-scatter(f32[1024]{0} %get-tuple-element.13226, f32[1024]{0} %get-tuple-element.13227, f32[1024,1024]{1,0} %get-tuple-element.13228, f32[1024]{0} %get-tuple-element.13229, f32[1024]{0} %get-tuple-element.13230, /*index=5*/f32[1024]{0} %get-tuple-element.13231, f32[512,1024]{1,0} %get-tuple-element.13232, f32[32,1024]{1,0} %get-tuple-element.13233, f32[1024]{0} %get-tuple-element.13234, f32[1024]{0} %get-tuple-element.13235, /*index=10*/f32[1024,1024]{1,0} %get-tuple-element.13236, f32[1024]{0} %get-tuple-element.13237, f32[1024]{0} %get-tuple-element.13238, f32[1024]{0} %get-tuple-element.13239, f32[512,1024]{1,0} %get-tuple-element.13240, /*index=15*/f32[32,1024]{1,0} %get-tuple-element.13241, f32[1024]{0} %get-tuple-element.13242, f32[1024]{0} %get-tuple-element.13243, f32[1024,1024]{1,0} %get-tuple-element.13244, f32[1024]{0} %get-tuple-element.13245, /*index=20*/f32[1024]{0} %get-tuple-element.13246, f32[1024]{0} %get-tuple-element.13247, f32[512,1024]{1,0} %get-tuple-element.13248, f32[32,1024]{1,0} %get-tuple-element.13249, f32[1024]{0} %get-tuple-element.13250, /*index=25*/f32[1024]{0} %get-tuple-element.13251, f32[1024,1024]{1,0} %get-tuple-element.13252, f32[] %get-tuple-element.13253), replica_groups={}, dimensions={0}, to_apply=%AddComputation.13254, metadata={op_type="xla__reduce_scatter" op_name="xla__reduce_scatter"}
The trick is you need to find the index of the tuple and see where it is being used. We might eventually find that place, or find out this value is never used. This graph is way too big, so I can't trace down where it is eventually used.
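To make that tracing less painful on a large dump, here is a minimal sketch (the file path and value name below are placeholders, not from the actual dump) that prints every line of a raw HLO text dump referencing a given value:

```python
# Minimal sketch: print every line in a raw HLO text dump that mentions a
# given value name, to help trace where a tuple element is eventually used.
# The path and value name are placeholders.
def find_uses(path, value_name):
    with open(path) as f:
        for lineno, line in enumerate(f, 1):
            if value_name in line:
                print(lineno, line.strip()[:200])

find_uses('dumped_graphs.hlo', '%get-tuple-element.13249')
```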
Sorry Jack for the large graph!
It can be repro'ed on an even larger (5 billion params) model. I didn't try a smaller model yet.
Added a smaller repro here: https://github.com/hjm-aws/dump/blob/main/reduce_scatter_test-1-layer-mlm.hlo. It has 3k lines per HLO module. It was generated with a new revision of https://gist.github.com/hjm-aws/d3b402535db6729b30678eab15faafda. I found that embedded FSDP (having an FSDP-wrapped module within another FSDP-wrapped module) is a necessary condition to repro the problem.
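For reference, a minimal sketch of that embedded-FSDP shape, assuming torch_xla's XlaFullyShardedDataParallel wrapper and assuming the reduce_scatter_bucket_size_mb value from the repro is passed straight to it (module sizes are made up):

```python
import torch.nn as nn
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# An FSDP-wrapped submodule nested inside another FSDP-wrapped module, with
# bucketed reduce-scatter enabled (the kwarg name is taken from the repro
# steps; whether it is passed here is an assumption).
inner = FSDP(nn.Linear(1024, 1024), reduce_scatter_bucket_size_mb=20)
model = FSDP(
    nn.Sequential(inner, nn.Linear(1024, 1024)),
    reduce_scatter_bucket_size_mb=20)
```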
There are a bunch of HLO modules in the raw file above; I picked two that look somewhat similar, HloModule IrToHlo.3047 and HloModule IrToHlo.2971. HloModule IrToHlo.3047 is the longer one.
The first big chunk of the diff is from a bunch of constant 9 instances. I am not sure what happened here; by default pytorch/xla only has constant 0 and 1. @hjm-aws did you guys override that behavior somewhere?
I checked that none of these constants are actually used in the computation. From the metadata it seems like they come from sum:
%constant.2017 = s32[] constant(9), metadata={op_type="aten__sum" op_name="aten__sum"}
so maybe somewhere in the sum we created a torch.tensor(9, device=xla_device). There are a total of 12 constant(9) on the left side and 12 on the right side; they just appear in different places of the graph.
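As a hedged illustration of that speculation (not a confirmed root cause), code along these lines would be enough to introduce such a node into the traced graph:

```python
import torch
import torch_xla.core.xla_model as xm

# Speculative illustration: materializing a Python scalar on the XLA device
# adds a node to the lazily-traced graph (as a constant or device data), which
# can then show up in the HLO even if the tensor is never consumed downstream.
nine = torch.tensor(9, device=xm.xla_device())
```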
I took a look at the remaining HLO and it seems like most of it is just doing tuple and untuple.
I have a theory. The way that mark_step works is that it looks at all live tensors that need to be synced and constructs an IR graph out of them (post-order traversal). What might have happened here is that it is easy to accumulate uncleared tensors across multiple steps, and those tensors are slightly different for each run. One common reason for this is additional work done after a single step, like logging, which will pollute the graph.
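A minimal sketch of what that could look like on the user side (train_step below is a made-up stand-in, not code from the repro):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

def train_step():
    # made-up stand-in for a real fwd/bwd step that builds lazy IR
    x = torch.randn(4, 4, device=device)
    return (x * x).sum()

# Keeping an extra live XLA tensor across steps (e.g. a running loss kept for
# logging) adds pending IR that mark_step() has to sync, so the traced graph
# can differ from a run that does not keep such tensors alive.
running_loss = torch.zeros((), device=device)
for step in range(3):
    loss = train_step()
    running_loss += loss.detach()  # extra pending IR accumulated across steps
    xm.mark_step()
```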
One additional debugging step we can take is to dump all of the live tensors before each mark_step; I expect us to see a different set of live tensors every time, which would explain the HLO difference. You can do that by:
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 32e2fe72..b4584c21 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -942,6 +942,8 @@ def _run_step_closures():
 def mark_step():
+  print(torch_xla._XLAC._xla_tensors_report(0, 'xla:0'))
   if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
     print(
         'torch_xla.core.xla_model::mark_step\n',
I would recommend setting a high threshold like 500 (it is 0 in the diff above, which will print every live tensor), so that it only prints tensors with more than 500 IR nodes attached. This API prints the pending IR attached to every live tensor, so the output can be super long.
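Based on the call shown in the diff above and the threshold behavior described here, the same report can also be printed ad hoc with a higher threshold:

```python
import torch_xla

# Only report live tensors whose pending IR graph has more than 500 nodes
# (the diff above uses 0, which prints every live tensor).
print(torch_xla._XLAC._xla_tensors_report(500, 'xla:0'))
```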
🐛 Bug
In the FSDP backward pass, if we accumulate some callbacks and invoke them later in one batch, then different runs can result in slightly different computation graphs and cause recompilation.
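For readers unfamiliar with the bucketing referred to here, a simplified, hypothetical sketch of the pattern (not the actual FSDP implementation):

```python
# Hypothetical illustration only: reduce-scatter callbacks are queued during
# backward and flushed together once a size threshold is reached, so the
# grouping and ordering of the deferred work can vary between runs, which in
# turn changes the traced graph.
class ReduceScatterBucket:
    def __init__(self, bucket_size_mb=20):
        self.bucket_size_bytes = bucket_size_mb * 1024 * 1024
        self.pending = []
        self.pending_bytes = 0

    def add(self, nbytes, callback):
        self.pending.append(callback)
        self.pending_bytes += nbytes
        if self.pending_bytes >= self.bucket_size_bytes:
            self.flush()

    def flush(self):
        # callbacks accumulated above are invoked later, in one batch
        for cb in self.pending:
            cb()
        self.pending.clear()
        self.pending_bytes = 0
```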
To Reproduce
Steps to reproduce the behavior:
Set reduce_scatter_bucket_size_mb=20, which will cause many callbacks to be accumulated and called together later.
Expected behavior
Observe the generated HLO dump. An example dump is here:
https://github.com/hjm-aws/dump/blob/main/reduce_scatter_test-1101-16_49_12.hlo.zip
https://github.com/hjm-aws/dump/blob/main/reduce_scatter_test-1-layer-mlm.hlo
We should see two unique graphs with more than 35K instructions. A smaller one is generated during the first run of the fwd/bwd pass, and a bigger one is generated during all subsequent runs of the fwd/bwd pass. In fact, we can see 4 such graphs: two extra unique graphs are generated in the subsequent fwd/bwd runs. In the example dump above, see the following graphs:
HloModule IrToHlo.38523
HloModule IrToHlo.39056
HloModule IrToHlo.38547
They only have slight differences, like a different order of the parameters (a small helper for listing the modules in a raw dump is sketched below).
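A sketch of such a helper, assuming each module in the text dump starts with a line beginning with "HloModule" (the file name is a placeholder):

```python
from collections import OrderedDict

def list_modules(path):
    # Count the number of text lines belonging to each HloModule in the dump,
    # as a rough proxy for module size, to spot the near-duplicate graphs.
    modules, current = OrderedDict(), None
    with open(path) as f:
        for line in f:
            if line.startswith('HloModule'):
                current = line.split()[1].rstrip(',')
                modules[current] = 0
            if current is not None:
                modules[current] += 1
    return modules

for name, nlines in list_modules('dumped_graphs.hlo').items():
    print(name, nlines)
```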
Note
If we set reduce_scatter_bucket_size_mb=0 such that every callback is called right after each reduce-scatter call, meaning no accumulation, then the issue will not happen.
Environment
cc: @ronghanghu