pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

New all-gather API takes much more memory than 1.10 all-gather implementation via all-reduce #3510

Open ronghanghu opened 2 years ago

ronghanghu commented 2 years ago

## 🐛 Bug

In our internal tests, the new `xm.all_gather` API implemented in https://github.com/pytorch/xla/pull/3275 takes significantly more memory to execute than the previous all-gather implementation via `all_reduce` in PyTorch XLA 1.10 (https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615).

For example, as shown in the reproducing steps below, the new `xm.all_gather` API fails to all-gather a 512 MB tensor on v3-8, even though the result is only a 4 GB output tensor, much smaller than the 16 GB of memory on each TPU v3 core. Meanwhile, the old `xm.all_gather` API in PyTorch XLA 1.10 handles this case without a problem.
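The sizes quoted above can be checked with a few lines of arithmetic. This sketch assumes float32 tensors (4 bytes per element) and the 8 processes that the repro script below spawns on a v3-8. It says nothing about what the compiled all-gather actually allocates.

```python
# Back-of-the-envelope sizes for the repro script.
# Assumptions: float32 (4 bytes per element) and 8 cores on a v3-8 (nprocs=8).
BYTES_PER_FLOAT32 = 4
WORLD_SIZE = 8
MiB = 1024**2
GiB = 1024**3

elements_per_core = 1024**3 // WORLD_SIZE            # length of t1 on each core
input_bytes = elements_per_core * BYTES_PER_FLOAT32  # size of t1 on each core
output_bytes = input_bytes * WORLD_SIZE              # all-gather output: every core's shard, held on each core

print(f"input per core:  {input_bytes // MiB} MiB")   # input per core:  512 MiB
print(f"output per core: {output_bytes // GiB} GiB")  # output per core: 4 GiB
```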

It is unexpected that the new `xm.all_gather` API takes so much memory to execute. (A well-implemented all-gather should in principle need about the same amount of memory as the output tensor.) This breaks many large-model use cases, e.g. those with a large number of sharded parameters or large tensors, or ZeRO (FSDP) in #3431.

One workaround is for large-model users to manually revert to the previous all-gather implementation via all_reduce in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615. However, https://github.com/pytorch/xla/issues/3506 partially prevents this workaround, since xm.all_reduce currently cannot be used together with xm.reduce_scatter.

## To Reproduce

1. Allocate a v3-8 TPU VM with the `tpu-vm-pt-1.10` runtime and install the 20220415 nightly builds of torch, torchvision, and torch_xla, while keeping the 20220408 build of libtpu (the newer 20220415 libtpu was reported as broken in https://github.com/pytorch/xla/issues/3502#issuecomment-1099777942).

   ```bash
   # torch, torchvision and torch_xla 20220415
   sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220415-cp38-cp38-linux_x86_64.whl
   sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220415-cp38-cp38-linux_x86_64.whl
   sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220415-cp38-cp38-linux_x86_64.whl

   # libtpu 20220408
   sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
   ```


2. Save the following content to a Python file (e.g. `/home/ronghanghu/test_all_gather_only_mem.py` below).

   ```python
   import torch
   import torch_xla.core.xla_model as xm
   import torch_xla.distributed.xla_multiprocessing as xmp


   def _mp_fn(index):
       world_size = xm.xrt_world_size()
       device = xm.xla_device()

       t1 = torch.ones(1024**3 // world_size, device=device)  # (4 GB // world_size) in float32, i.e. 512 MB on v3-8
       xm.mark_step()
       print(f"t1.sum(): {t1.sum()}, mem: {xm.get_memory_info(device)}")

       t2 = xm.all_gather(t1).flatten()  # 4 GB in float32
       del t1
       xm.mark_step()
       print(f"t2.sum(): {t2.sum()}, mem: {xm.get_memory_info(device)}")


   if __name__ == "__main__":
       xmp.spawn(_mp_fn, args=(), nprocs=8)
   ```


3. Run this file on the v3-8 TPU VM:

   ```bash
   python3 /home/ronghanghu/test_all_gather_only_mem.py
   ```

   It prints

   ```
   ...
   2022-04-18 06:42:22.783996: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at tpu_execute_op.cc:266 : RESOURCE_EXHAUSTED: Attempting to reserve 12.00G at the bottom of memory. That was not possible. There are 10.98G free, 0B reserved, and 10.98G reservable.
   ...
   ```


4. If we revert `xm.all_gather` to the older version implemented via all_reduce (by adding the snippet below to the script), then the example `/home/ronghanghu/test_all_gather_only_mem.py` runs without memory errors.

   ```python
   import torch.nn.functional as F
   import torch_xla.core.xla_model as xm


   def old_all_gather(value, dim=0, groups=None):
       """
       This is the older all_gather implementation via all_reduce in PyTorch XLA 1.10 in
       https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615
       """
       if dim < 0:
           dim = value.dim() + dim
       size = value.size(dim)
       # F.pad takes (left, right) pairs starting from the last dimension.
       padding = [0] * (2 * value.dim())
       ordinal = xm.get_ordinal()
       if groups is None:
           left, right = ordinal, xm.xrt_world_size() - 1 - ordinal
       else:
           ordinals = dict()
           for g in groups:
               for i, x in enumerate(g):
                   ordinals[x] = (i, len(g) - 1 - i)
           left, right = ordinals[ordinal]
       idx = value.dim() - 1 - dim
       # Place this shard at its own offset inside an output-sized, zero-filled tensor,
       padding[2 * idx] = left * size
       padding[2 * idx + 1] = right * size
       # so that a SUM all-reduce across replicas yields the concatenation of all shards.
       return xm.all_reduce(xm.REDUCE_SUM, F.pad(value, padding), groups=groups)


   xm.all_gather = old_all_gather
   ```
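The following sketch shows why this workaround holds an output-sized buffer on every core. It is a CPU-only simulation of the padding + SUM all-reduce trick using plain Python lists. `pad_to_output` and `simulated_all_reduce_sum` are hypothetical stand-ins for `F.pad` and `xm.all_reduce`, and only the `groups=None` case of a 1D tensor is modelled.

```python
# CPU-only simulation with plain Python lists (no torch_xla needed).
# Each "core" holds a 1D shard; the old all_gather pads it with zeros so that it sits at
# its own offset inside an output-sized buffer, then a SUM all-reduce merges the buffers.


def pad_to_output(shard, ordinal, world_size):
    """Hypothetical stand-in for F.pad on a 1D shard (the groups=None case)."""
    size = len(shard)
    left, right = ordinal, world_size - 1 - ordinal
    return [0.0] * (left * size) + list(shard) + [0.0] * (right * size)


def simulated_all_reduce_sum(per_core_tensors):
    """Elementwise sum across cores; every core receives the same result."""
    total = [sum(values) for values in zip(*per_core_tensors)]
    return [list(total) for _ in per_core_tensors]


def all_gather_via_all_reduce(shards):
    world_size = len(shards)
    padded = [pad_to_output(s, i, world_size) for i, s in enumerate(shards)]
    return padded, simulated_all_reduce_sum(padded)


shards = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 simulated cores, 2 elements each
padded, gathered = all_gather_via_all_reduce(shards)
print(len(padded[0]))  # 6: each core's padded input is already output-sized
print(gathered[0])     # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

In this simulation, each core's padded input is already as large as the gathered output before any communication happens.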



## Expected behavior

The new `xm.all_gather` API should not take so much memory to execute. In principle, it should need about the same amount of memory as the output tensor. Otherwise, it blocks many practical scaling applications, such as Microsoft's ZeRO optimizer.

## Environment

 - Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
 - torch_xla version: 20220415 nightly from `tpu-vm-pt-1.10` (see Step 1 above)

## Additional context

Based on the error message `Attempting to reserve 12.00G at the bottom of memory. That was not possible.`, it is also unclear why the new `xm.all_gather` API cannot run the example above at all: since each TPU v3 core has 16 GB of memory, shouldn't it be able to run even if the new all-gather takes 12 GB?

cc: @JackCaoG 
JackCaoG commented 2 years ago

Thanks for reporting. I am able to repro this on my end. I am checking with Blake from the XLA team; this seems to come from all_gather not being optimized for 1D tensors. If I change the input shape to `(1024, 1024**2)`, the test above will pass. I am working with Blake to optimize this on the compiler end.

I am assuming you are not blocked right now, since you can use the old all_gather for now. Is this true?

ronghanghu commented 2 years ago

@JackCaoG Thanks for quickly reproducing this issue!

> I am assuming you are not blocked right now, since you can use the old all_gather for now. Is this true?

Actually, I'm blocked here, because I cannot revert to the old all_gather implementation via all_reduce at the moment: all_reduce currently doesn't work with reduce_scatter (see https://github.com/pytorch/xla/issues/3506), and I often need to use all-gather and reduce-scatter together (e.g. for a ZeRO implementation). I also cannot revert to older torch_xla builds, since I need `xm.optimization_barrier_`.

> If I change the input shape to `(1024, 1024**2)`, the test above will pass. I am working with Blake to optimize this on the compiler end.

Yeah, this indeed looks weird, since in principle the two cases should consume the same amount of memory. It would be great to optimize the 1D tensor case as well.

JackCaoG commented 2 years ago

@ronghanghu I want to get some context here. Does FSDP require all_gather to happen on a 1D tensor?

ronghanghu commented 2 years ago

> Does FSDP require all_gather to happen on a 1D tensor?

I think I could also try hacking FSDP to get it to work on 2D tensors. I need to handle 1D tensors, but I can try reshaping them to 2D. However, naively changing the tensor to 2D as `t1 = torch.ones(1, 1024**3 // world_size, device=device)` still doesn't work here.

I wonder whether there is a specific requirement on the tensor shape for the current all-gather implementation to work efficiently (e.g. should the first dimension be as large as possible, such as 1024)? If so, I can try some padding in my case and reshape the tensor to have 1024 as the first dimension.

JackCaoG commented 2 years ago

Can you try padding to a multiple of 128 and reshaping to (x, 128)? I can also try to do that on the PT/XLA end so you don't need to worry about it.
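Here is a plain-Python sketch of this suggestion. It assumes the gather concatenates each core's rows along dim 0 in ordinal order, as an all-gather does. The helpers (`pad_and_reshape`, `simulated_all_gather_dim0`, `unpad_gathered`) and the sizes are hypothetical. In real code, the rows would be a torch tensor of shape (x, 128) passed to `xm.all_gather`, and unpadding would mean reshaping to (world_size, -1), slicing each row to the original shard length, and flattening.

```python
# Plain-Python sketch of the "pad to a multiple of 128, reshape to (x, 128)" workaround.
# Rows of a 2D tensor are modelled as lists; the gather is simulated on the CPU.
LANE = 128  # multiple suggested in the thread


def pad_and_reshape(shard, lane=LANE):
    """Zero-pad a 1D shard to a multiple of `lane` and split it into rows of length `lane`."""
    pad = (-len(shard)) % lane
    padded = list(shard) + [0.0] * pad
    return [padded[i:i + lane] for i in range(0, len(padded), lane)]


def simulated_all_gather_dim0(per_core_rows):
    """Concatenate every core's rows along dim 0, like an all-gather on a (x, lane) tensor."""
    return [row for rows in per_core_rows for row in rows]


def unpad_gathered(gathered_rows, world_size, shard_len):
    """Split the gathered data back per core and keep the first shard_len elements of each."""
    flat = [v for row in gathered_rows for v in row]
    per_core = len(flat) // world_size
    return [v for c in range(world_size) for v in flat[c * per_core:c * per_core + shard_len]]


world_size, shard_len = 4, 300  # hypothetical sizes: 300 is not a multiple of 128
shards = [[float(c * shard_len + i) for i in range(shard_len)] for c in range(world_size)]
reshaped = [pad_and_reshape(s) for s in shards]
print(len(reshaped[0]), len(reshaped[0][0]))  # 3 128  (300 padded to 384 = 3 x 128)
full = unpad_gathered(simulated_all_gather_dim0(reshaped), world_size, shard_len)
print(full == [float(i) for i in range(world_size * shard_len)])  # True
```

The padding is stripped per core rather than only from the tail of the flattened result, because every shard carries its own padding.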

ronghanghu commented 2 years ago

Thanks! I'll first try padding to (x, 128) on my end and see if it resolves my use case. It seems to work on the example above, so I'll try it on more cases.

As a long-term solution, I think it would be better to address this upstream, directly in the XLA ops in libtpu (I guess it's TPU-specific and doesn't affect, e.g., CUDA all-gather), rather than patching it in torch_xla (though this is not urgent for me).

JackCaoG commented 2 years ago

@ronghanghu Can you give me an example where FSDP needs to do all_gather on a giant 1D tensor? The XLA team is under the impression that this is unlikely to be a real use case, so if we can show that this is a requirement, it will be easier to push a fix.

ronghanghu commented 2 years ago

@JackCaoG For FSDP or ZeRO implementations, we often flatten and concatenate all the parameters in a layer (such as a transformer block) into a single 1D `full_parameter` vector. Then we shard this `full_parameter` vector across devices, so each device holds a 1D `sharded_parameter`. During training, we do all_gather on this 1D `sharded_parameter` tensor to rebuild the `full_parameter`. The example above in this issue reflects this use case. (Flattening all the parameters in a layer into a single 1D vector usually means fewer API calls and better performance than sharding and gathering each parameter separately.)
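The flatten / shard / rebuild pattern described above can be illustrated with a toy, CPU-only sketch using plain Python lists. This is not the actual FSDP or ZeRO code. The helper names, the padding to a multiple of world_size, and the example layer are assumptions for illustration, and the all-gather is simulated by concatenating the shards in rank order.

```python
# Toy sketch of flatten -> shard -> all-gather -> unflatten, with plain Python lists.
# A real implementation would use torch tensors and xm.all_gather for the rebuild step.
from math import prod


def flatten_params(params):
    """Concatenate all (shape, values) parameters of a layer into one 1D full_parameter list."""
    shapes = [shape for shape, _ in params]
    full = [v for _, values in params for v in values]
    return full, shapes


def shard(full, world_size):
    """Pad full_parameter so it divides evenly, then give each rank one contiguous 1D slice."""
    pad = (-len(full)) % world_size
    padded = full + [0.0] * pad
    size = len(padded) // world_size
    return [padded[r * size:(r + 1) * size] for r in range(world_size)]


def rebuild(shards, shapes):
    """Simulated all-gather (concatenation), then split back per parameter; trailing padding is never read."""
    full = [v for s in shards for v in s]
    out, offset = [], 0
    for shape in shapes:
        n = prod(shape)
        out.append((shape, full[offset:offset + n]))
        offset += n
    return out


# Hypothetical layer: a 3x2 weight and a bias of length 3 (9 elements in total).
layer = [((3, 2), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), ((3,), [7.0, 8.0, 9.0])]
full, shapes = flatten_params(layer)
shards = shard(full, world_size=4)       # 9 elements padded to 12 -> 3 per rank
print([len(s) for s in shards])          # [3, 3, 3, 3]
print(rebuild(shards, shapes) == layer)  # True
```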

ronghanghu commented 2 years ago

> Can you try padding to a multiple of 128 and reshaping to (x, 128)? I can also try to do that on the PT/XLA end so you don't need to worry about it.

Update on this:

I'm debugging the all-gather case on a v3-128 pod. I can also try switching to the old all-gather-via-all-reduce implementation first, once https://github.com/pytorch/xla/pull/3511 is merged.

ronghanghu commented 2 years ago

A further update on this: switching to the old all-gather via all-reduce implementation (using 20220419 builds containing https://github.com/pytorch/xla/pull/3511) didn't resolve the problem.

On v3-8, the alternative all-gather implementation using padding + all-reduce takes much more memory than directly using all-gather after reshaping to (x, 128), and it now breaks the 10B+ transformer model. This is true even if I change the all-reduce op in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L615 from

```python
return all_reduce(REDUCE_SUM, F.pad(value, padding), groups=groups)
```

to an in-place operation:

```python
padded_value = F.pad(value, padding)
# Passing a list makes xm.all_reduce update padded_value in place.
xm.all_reduce(xm.REDUCE_SUM, [padded_value], groups=groups)
return padded_value
```

(I haven't figured out whether the higher memory consumption of all-gather via all-reduce is due to the padding op or the all-reduce op yet. Still looking at it.)

JackCaoG commented 2 years ago

My understanding is that the padding trick is only useful for the native all-gather.

ronghanghu commented 2 years ago

> My understanding is that the padding trick is only useful for the native all-gather.

(The "padding" in the old all-gather via all-reduce refers to padding the tensor to the all-gather output shape and then all-reducing it, as in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L615, not padding to a multiple of 128.)

So my current observation is that the old all-gather via all-reduce still takes much more memory on v3-8 than the new all-gather with reshaping to (x, 128). However, the new all-gather with reshaping to (x, 128) still doesn't work well on a TPU pod.

JackCaoG commented 2 years ago

Ah, I guess it is a good thing then? Using the new all-gather + padding seems to work better.

ronghanghu commented 2 years ago

> Ah, I guess it is a good thing then? Using the new all-gather + padding seems to work better.

Yeah, the new all-gather + padding works better on v3-8.

But the new all-gather's memory consumption seems to grow with the number of TPU cores (for the same gather output tensor size) and I cannot apply it on v3-128 for now. I'm still looking into it. On the other hand, the total memory consumption of all-reduce doesn't grow with the number of TPU cores.

ronghanghu commented 2 years ago

(In our case, getting the model to run on v3-8 is not very useful by itself; we need it to run on pods such as v3-128 and v3-256 for actual model experimentation, which is what I'm trying to debug now. Some of the behavior on v3-8 doesn't always hold true on a pod.)

JackCaoG commented 2 years ago

@ronghanghu Can you provide a small repro? I can check with the XLA team.

ronghanghu commented 2 years ago

> @ronghanghu Can you provide a small repro? I can check with the XLA team.

Yeah, this is exactly what I'm working on now. Thanks for your help :)

ronghanghu commented 2 years ago

> @ronghanghu Can you provide a small repro? I can check with the XLA team.

Following up on this: my TPU OOM issue on v3-128 is actually due to fusion by the XLA compiler, not the all-gather op itself. It seems that the XLA fusion behavior of the same/similar program on v3-128 is sometimes different from v3-8.

If I insert xm.mark_step at a few places in my program (only for debugging purposes here as many mark-step calls significantly slow down the execution), then the TPU OOM doesn't happen on v3-128 either. The trick of first reshaping to (x, 128) and then doing all-gather works well on the v3-128 pod in this case.

I'm trying to see if I can prevent this XLA fusion by aggressively inserting xm.optimization_barrier_ for our real use cases. So far the xm.optimization_barrier_ API has been a GREAT feature for us. It saves a lot of memory (by preventing fusion), accelerates the compilation, stops the compiler from going rogue, and almost never decreases the execution speed in our cases.

Meanwhile, it would still be great if the XLA team could optimize the all-gather op on 1D tensors, and get all-gather and reduce-scatter to work together with pinned layout on both.

ronghanghu commented 2 years ago

It seems very hard to prevent some XLA fusions, and I still cannot get the same TPU memory usage with `xm.optimization_barrier_` as with explicit `xm.mark_step` calls. I'll try to create a small repro of this XLA fusion problem and submit a new issue.

JackCaoG commented 2 years ago

I chatted with Blake: there is a flag on the XLA end that optimizes the 1D-vector all-gather. However, it is currently turned off on TPU v3 (it is on for v4) due to a NaN issue. The XLA team is looking into this bug, but it is unsafe to turn on this optimization until they root-cause the NaN.

In the meantime, let's use the reshape trick to unblock FSDP, if that is OK with you :D.

ronghanghu commented 2 years ago

> I chatted with Blake: there is a flag on the XLA end that optimizes the 1D-vector all-gather. However, it is currently turned off on TPU v3 (it is on for v4) due to a NaN issue. The XLA team is looking into this bug, but it is unsafe to turn on this optimization until they root-cause the NaN.
>
> In the meantime, let's use the reshape trick to unblock FSDP, if that is OK with you :D.

Sounds good, and thanks for the update! I'll first use the reshape (x, 128) trick in my FSDP use case.