ronghanghu opened this issue 2 years ago
Thanks for reporting. I am able to repro this on my end. I am checking with Blake from the XLA team; this seems to come from all_gather not being optimized for 1D tensors. If I change the input shape to (1024, 1024**2), the test above will pass. I am working with Blake to optimize this on the compiler end.
I am assuming you are not blocked right now since you can use the old all_gather for now, is this true?
@JackCaoG Thanks for quickly reproducing this issue!
Actually I'm blocked here, because I cannot revert to the old all_gather implementation via all_reduce at this moment -- so far all_reduce doesn't work with reduce_scatter as in https://github.com/pytorch/xla/issues/3506, while I often need to use all-gather and reduce-scatter together (e.g. for a ZeRO implementation). And I cannot revert to the older torch_xla builds since I also need xm.optimization_barrier_.
Yeah, this difference between the 1D and the (1024, 1024**2) cases indeed looks weird (since in principle the two should consume identical memory). It would be great to optimize the 1D tensor case as well.
@ronghanghu I want to get some context here. Does FSDP require all_gather to happen on a 1D tensor?
I think I can also try hacking FSDP to get it to work on 2D tensors. I need to handle 1D tensors but can try reshaping them to 2D. However, it seems that naively changing the tensor to 2D as t1 = torch.ones(1, 1024**3 // world_size, device=device) still doesn't work here.
I wonder whether there is a specific requirement on the tensor shape for the current all-gather implementation to work efficiently (e.g. the first dimension should be as large as possible, such as 1024)? If so, I can try doing some padding in my case and reshape the tensor to have 1024 as the first dim size.
Can you try padding to a multiple of 128 and reshaping to (x, 128)? I can also try to do that on the PT/XLA end so you don't need to worry about it.
Thanks! I can first try padding to (x, 128) on my end and see if it resolves my use case. It seems to work in the example above, so I can try it on more cases.
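A minimal sketch of this reshape trick, assuming the native xm.all_gather (the helper name and padding bookkeeping below are illustrative, not a torch_xla API):

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

def all_gather_1d_via_2d(t1d):
    # pad a 1D tensor to a multiple of 128, reshape to (x, 128), run the native
    # all_gather, then flatten and strip the per-shard padding again
    world_size = xm.xrt_world_size()
    numel = t1d.numel()
    pad = -numel % 128
    padded = F.pad(t1d, (0, pad)).reshape(-1, 128)
    gathered = xm.all_gather(padded, dim=0)
    # each participant contributed (numel + pad) elements; drop the padding
    shards = gathered.reshape(world_size, -1)[:, :numel]
    return shards.reshape(-1)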
As a long-term solution, I feel it's probably better to address this more upstream, directly in the XLA ops in libtpu (I guess it's TPU-specific and doesn't affect e.g. CUDA all-gather), rather than patching it in torch_xla (but this is not urgent for me).
@ronghanghu Can you give me an example where FSDP needs to do all_gather on a giant 1D tensor? The XLA team is under the impression that this is unlikely to be a real use case, so if we can show that this is a requirement, it will be easier to push for a fix.
@JackCaoG For an FSDP or ZeRO implementation, we often flatten and concatenate all the parameters in a layer (such as a transformer block) into a single 1D full_parameter vector. Then we shard this full_parameter vector across devices, so each device holds a 1D sharded_parameter. During training, we all_gather this 1D sharded_parameter tensor to rebuild the full_parameter. The example above in this issue reflects this use case. (Flattening all the parameters of a layer into a single 1D vector usually results in fewer API calls and better performance than separately sharding and gathering each parameter tensor.)
(See the all_gather call in https://github.com/pytorch/xla/blob/1664252f66c29f93dfb00eb57d70e90ff44b07ef/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py#L888 in #3431.)
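A rough sketch of the flatten/shard/all-gather pattern described above (the helper names are illustrative; this is not the actual FSDP implementation):

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

def shard_layer_params(params, world_size, rank):
    # flatten and concatenate all parameters of a layer into one 1D vector
    full_parameter = torch.cat([p.detach().reshape(-1) for p in params])
    # pad so the vector splits evenly, then keep only this device's 1D shard
    pad = -full_parameter.numel() % world_size
    full_parameter = F.pad(full_parameter, (0, pad))
    shard_size = full_parameter.numel() // world_size
    return full_parameter[rank * shard_size:(rank + 1) * shard_size]

def rebuild_full_parameter(sharded_parameter):
    # during training: all-gather the 1D shard on every device to rebuild the
    # (padded) full_parameter, which is then reshaped back into layer weights
    return xm.all_gather(sharded_parameter, dim=0)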
Update on this: I'm debugging the all-gather case on a v3-128 pod. I can also first try switching back to the old all-gather-via-all-reduce implementation after https://github.com/pytorch/xla/pull/3511 is merged.
A further update on this: switching to the old all-gather via all-reduce implementation (using the 20220419 builds containing https://github.com/pytorch/xla/pull/3511) didn't resolve the problem.
On v3-8, the alternative implementation of all-gather using padding + all-reduce takes much more memory to run than directly using all-gather after reshaping to (x, 128), and it now breaks the 10B+ transformer model. This is true even if I change the all-reduce op in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L615 from
return all_reduce(REDUCE_SUM, F.pad(value, padding), groups=groups)
to an in-place operation:
padded_value = F.pad(value, padding)
xm.all_reduce(xm.REDUCE_SUM, [padded_value], groups=groups)
return padded_value
(I haven't figured out whether the higher memory consumption of all-gather via all-reduce is due to the padding op or the all-reduce op yet. Still looking into it.)
My understanding is that the padding trick is only useful for the native all-gather.
(The "padding" in the old all-gather via all-reduce refers to padding the tensor to all-gather output shape and then all-reducing it as in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L615, not padding to a multiple of 128)
So my current observation is that the old all-gather via all-reduce still takes a lot more memory on v3-8 than the new all-gather with reshaping to (x, 128). But the latter still doesn't work well on a TPU pod.
Ah, I guess it is a good thing then? Using the new all-gather + padding seems to work better.
Yeah, the new all-gather + padding works better on v3-8.
But the new all-gather's memory consumption seems to grow with the number of TPU cores (for the same gathered output tensor size), and I cannot apply it on v3-128 for now. I'm still looking into it. On the other hand, the total memory consumption of all-reduce doesn't grow with the number of TPU cores.
(In our case, getting the model to run on v3-8 is not very useful by itself; we need to get it to run on pods such as v3-128 and v3-256 for actual model experimentation, which is what I'm trying to debug now. It seems that some of the behavior on v3-8 doesn't always hold true on a pod.)
@ronghanghu Can you provide a small repro? I can check with the XLA team.
Yeah, this is exactly what I'm working on now. Thanks for your help :)
Following up on this: my TPU OOM issue on v3-128 is actually due to fusion by the XLA compiler, not the all-gather op itself. It seems that the XLA fusion behavior of the same/similar program on v3-128 is sometimes different from v3-8.
If I insert xm.mark_step at a few places in my program (only for debugging purposes here, as many mark-step calls significantly slow down the execution), then the TPU OOM doesn't happen on v3-128 either. The trick of first reshaping to (x, 128) and then doing all-gather works well on the v3-128 pod in this case.
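For reference, a debugging-only sketch of inserting xm.mark_step between steps (the loop and helper names are hypothetical):

import torch_xla.core.xla_model as xm

def debug_train_steps(batches, forward_backward):
    # hypothetical debug loop: forward_backward builds the lazy graph for one
    # step; mark_step cuts the graph so XLA cannot fuse across this boundary
    for batch in batches:
        loss = forward_backward(batch)
        xm.mark_step()  # much slower; only for localizing the OOM-inducing fusion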
I'm trying to see if I can prevent this XLA fusion by aggressively inserting xm.optimization_barrier_ in our real use cases. So far the xm.optimization_barrier_ API has been a GREAT feature for us. It saves a lot of memory (by preventing fusion), accelerates compilation, stops the compiler from going rogue, and almost never decreases the execution speed in our cases.
Meanwhile, it would still be great if the XLA team could optimize the all-gather op on 1D tensors, and get all-gather & reduce-scatter to work together with a pinned layout on both.
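A minimal sketch of how such a barrier can be placed, assuming a list of live XLA tensors at the cut point (the wrapper name is illustrative):

import torch_xla.core.xla_model as xm

def fusion_boundary(tensors):
    # apply an in-place optimization barrier to a list of XLA tensors so the
    # compiler does not fuse computations across this point in the graph
    xm.optimization_barrier_(tensors)
    return tensors

The placement that actually saves memory depends on the model and the graph, as the rest of this thread discusses.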
It seems very hard to prevent some XLA fusions, and I still cannot get the same TPU memory usage with xm.optimization_barrier_ as with explicit xm.mark_step calls. I'll see if I can create a small repro example of this XLA fusion problem and submit a new issue.
I chatted with Blake; there is a flag on the XLA end that will optimize the 1D-vector all-gather. However, it is currently turned off on v3 TPUs (it is on for v4) due to a NaN issue. The XLA team is looking into this bug, but it is unsafe to turn on this optimization until they root-cause the NaN.
In the meantime, let's use the reshape trick to unblock FSDP if that is OK with you :D
Sounds good, and thanks for the update! I'll first use the reshape (x, 128) trick in my FSDP use case.
🐛 Bug
In our internal tests, the new xm.all_gather API implemented in https://github.com/pytorch/xla/pull/3275 is shown to take significantly more memory to execute than the previous all-gather implementation via all_reduce in PyTorch XLA 1.10 (in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615).
For example, as shown in the reproduction steps below, the new xm.all_gather API fails to all-gather a 512 MB tensor on v3-8 (which would result in only a 4 GB output tensor, much smaller than the 16 GB total TPU memory size). Meanwhile, the old xm.all_gather API in PyTorch XLA 1.10 can handle this case without a problem.
It is weird and unexpected that the new xm.all_gather API takes so much memory to execute. (A well-implemented all-gather should in principle take roughly the same amount of memory as the output tensor size.) It is breaking many large-model use cases, e.g. when using a large number of sharded parameters or large tensors, or ZeRO (FSDP) in #3431.
One workaround is for large-model users to manually revert to the previous all-gather implementation via all_reduce in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615. However, https://github.com/pytorch/xla/issues/3506 partially prevents this workaround, since xm.all_reduce now cannot work with xm.reduce_scatter.
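As a quick sanity check on the sizes above (assuming float32 tensors and the 8 cores of a v3-8):

world_size = 8
numel = 1024 ** 3                              # elements in the gathered output
input_mb = numel // world_size * 4 / 2 ** 20   # 512.0 -> MB per-core input shard
output_gb = numel * 4 / 2 ** 30                # 4.0   -> GB gathered output (< 16 GB HBM)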
To Reproduce
Use a v3-8 TPU VM with the tpu-vm-pt-1.10 runtime and install the 20220415 version of torch, torchvision, and torch_xla, while keeping the 20220408 version of libtpu (since the newer 20220415 version was reported bad in https://github.com/pytorch/xla/issues/3502#issuecomment-1099777942):
# libtpu 20220408
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    world_size = xm.xrt_world_size()
    device = xm.xla_device()
    # ... (all-gather test; see the sketch below)


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
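The elided body of _mp_fn is not shown above; a plausible reconstruction, consistent with the 512 MB per-core shard and the 2D variant t1 = torch.ones(1, 1024**3 // world_size, device=device) quoted earlier in this thread (a sketch, not necessarily the exact original script):

    t1 = torch.ones(1024**3 // world_size, device=device)  # ~512 MB float32 shard per core
    t2 = xm.all_gather(t1)                                  # gathers into a ~4 GB tensor
    xm.mark_step()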
python3 /home/ronghanghu/test_all_gather_only_mem.py
... 2022-04-18 06:42:22.783996: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at tpu_execute_op.cc:266 : RESOURCE_EXHAUSTED: Attempting to reserve 12.00G at the bottom of memory. That was not possible. There are 10.98G free, 0B reserved, and 10.98G reservable.
...
import torch.nn.functional as F

def old_all_gather(value, dim=0, groups=None):
  """
  This is the older all_gather implementation via all_reduce in PyTorch XLA 1.10 in
  https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615
  """
  if dim < 0:
    dim = value.dim() + dim
  size = value.size(dim)
  padding = [0] * (2 * value.dim())
  ordinal = xm.get_ordinal()
  if groups is None:
    left, right = ordinal, xm.xrt_world_size() - 1 - ordinal
  else:
    ordinals = dict()
    for g in groups:
      for i, x in enumerate(g):
        ordinals[x] = (i, len(g) - 1 - i)
    left, right = ordinals[ordinal]
  idx = value.dim() - 1 - dim
  padding[2 * idx] = left * size
  padding[2 * idx + 1] = right * size
  return xm.all_reduce(xm.REDUCE_SUM, F.pad(value, padding), groups=groups)
xm.all_gather = old_all_gather