pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Spmd pre-training llama2 multi-machine training so slow? #6778

Open mars1248 opened 3 months ago

mars1248 commented 3 months ago

SPMD trains at normal speed using eight GPUs on a single machine, but the communication overhead grows rapidly with multiple machines. Devices: A100 GPUs, 8 per machine, 2 machines. The SPMD strategy is:

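# Assumes num_devices = 16 (2 machines x 8 A100s) and device_ids = range(num_devices).
for name, param in model.named_parameters():
    # Mesh of shape (num_devices, 1, ..., 1): dim 0 of every parameter is split
    # across all 16 devices, i.e. across both machines.
    shape = (num_devices,) + (1,) * (len(param.shape) - 1)
    mesh = xs.Mesh(device_ids, shape)
    xs.mark_sharding(param, mesh, tuple(range(len(param.shape))))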
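To make the inter-machine traffic concrete, here is a pure-Python sketch (it does not use torch_xla) of which machines the sharding group touches under this 1D mesh. It assumes 16 devices, with device IDs 0 to 7 on machine 0 and 8 to 15 on machine 1; `machines_per_group` is a hypothetical helper.

```python
def machines_per_group(groups, devices_per_machine=8):
    """Return, for each device group, the sorted list of machines it touches.

    Assumes device d lives on machine d // devices_per_machine.
    """
    return [sorted({d // devices_per_machine for d in g}) for g in groups]


num_devices = 16  # 2 machines x 8 A100s
# With mesh shape (num_devices, 1, ...) and dim 0 mapped to mesh axis 0, each
# parameter is split across a single group containing all 16 devices, so every
# all-gather / reduce-scatter of a parameter involves both machines.
one_d_groups = [list(range(num_devices))]
print(machines_per_group(one_d_groups))  # [[0, 1]]
```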

The profile result is:

[Image: profile screenshot]

JackCaoG commented 3 months ago

I am not a GPU expert, but my take is that communication across hosts is a lot slower than within a host; that's most likely why you observe the performance difference. @vanbasten23 @yeounoh FYI

mars1248 commented 3 months ago

@vanbasten23 @JackCaoG @yeounoh Thank you for your reply. I understand that inter-machine communication will be much slower. My question is whether my SPMD policy is misconfigured and causing extra communication between machines. I want to use SPMD to implement tensor parallelism within a machine and data parallelism across machines, like DeepSpeed.

JackCaoG commented 3 months ago

@jonb377 can probably give you more insight into how to achieve that. I think your situation is similar to multi-pod training, and you might want hyper-sharding.

jonb377 commented 3 months ago

Hey @mars1248!

Like @JackCaoG mentioned, this is achievable using SPMD; you'll just need to modify your mesh so that the different machines' devices are grouped on a separate axis. For example:

import numpy as np
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()  # enable SPMD mode before creating XLA tensors
device_ids = np.arange(xr.global_runtime_device_count())
num_machine = 2
axis_names = ('replica', 'fsdp')
mesh_shape = (num_machine, len(device_ids) // num_machine)

# Just a note - it's best practice to use a single mesh across the program and to reshape via the partition_spec.
mesh = xs.Mesh(device_ids, mesh_shape, axis_names)
for name, param in model.named_parameters():
    # By omitting the `replica` axis, the parameters are replicated along it.
    # This achieves FSDP within and data parallelism across machines.
    partition_spec = ('fsdp',) + (None,) * (len(param.shape) - 1)
    xs.mark_sharding(param, mesh, partition_spec)
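The effect of the 2D mesh can be checked with a similar pure-Python sketch. It assumes `xs.Mesh` lays the device IDs out row-major (like `np.arange(16).reshape(2, 8)`) and that device `d` lives on machine `d // 8`; `axis_groups` and `machines` are hypothetical helpers, not torch_xla APIs.

```python
def axis_groups(mesh_shape, axis):
    """Groups of device IDs that differ only along `axis` of a row-major 2D mesh."""
    rows, cols = mesh_shape
    grid = [[r * cols + c for c in range(cols)] for r in range(rows)]
    if axis == 0:  # vary the row index: one group per column
        return [[grid[r][c] for r in range(rows)] for c in range(cols)]
    return [row[:] for row in grid]  # vary the column index: one group per row


def machines(group, devices_per_machine=8):
    """Sorted machine indices touched by a group of device IDs."""
    return sorted({d // devices_per_machine for d in group})


mesh_shape = (2, 8)  # axes ('replica', 'fsdp')
fsdp_groups = axis_groups(mesh_shape, axis=1)
replica_groups = axis_groups(mesh_shape, axis=0)
# Parameter all-gathers run over 'fsdp' groups, each confined to one machine.
print([machines(g) for g in fsdp_groups])  # [[0], [1]]
# Reduction across the data-parallel replicas uses 'replica' groups, which
# pair device d on machine 0 with device d + 8 on machine 1.
print(replica_groups[0], machines(replica_groups[0]))  # [0, 8] [0, 1]
```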

Then, any batch-dimension sharding should operate over the combined ('replica', 'fsdp') axes:

# Example input_sharding, same mesh from above, assuming 2D input tensor
# (e.g. passed to MpDeviceLoader through its input_sharding argument):
input_sharding = xs.ShardingSpec(mesh, (('replica', 'fsdp'), None))

# Example activation sharding (inside the model's forward pass):
query_states = self.q_proj(hidden_states)
xs.mark_sharding(query_states, mesh, (('replica', 'fsdp'), None, None))
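To see what the tuple axis `('replica', 'fsdp')` means for shapes, here is a pure-Python sketch that computes the per-device shape for a partition spec on the (2, 8) mesh. It is an illustration, not torch_xla's implementation: it only accepts axis names (not integer axis indices) and rejects dimensions that do not divide evenly. `local_shape` is a hypothetical helper and the tensor shapes are made-up examples.

```python
from math import prod


def local_shape(global_shape, mesh_shape, axis_names, partition_spec):
    """Per-device shape for a tensor sharded with an SPMD-style partition spec.

    Each spec entry is None (replicated), an axis name, or a tuple of axis
    names whose mesh sizes multiply together.
    """
    if len(partition_spec) != len(global_shape):
        raise ValueError("partition_spec must have one entry per dimension")
    sizes = dict(zip(axis_names, mesh_shape))
    out = []
    for dim, spec in zip(global_shape, partition_spec):
        if spec is None:
            ways = 1
        elif isinstance(spec, tuple):
            ways = prod(sizes[a] for a in spec)
        else:
            ways = sizes[spec]
        if dim % ways:
            raise ValueError(f"dim {dim} is not divisible by {ways}")
        out.append(dim // ways)
    return tuple(out)


mesh_shape, names = (2, 8), ('replica', 'fsdp')
# 2D input (global_batch=64, seq_len=2048): the batch is split 16 ways.
print(local_shape((64, 2048), mesh_shape, names, (('replica', 'fsdp'), None)))  # (4, 2048)
# Weight sharded only on 'fsdp': split 8 ways, replicated across the 2 machines.
print(local_shape((4096, 11008), mesh_shape, names, ('fsdp', None)))  # (512, 11008)
```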

You can refer to our transformers fork for an example. There we support 2D sharding with replication across the dcn axis for multislice TPU training, so it's a bit more complicated than the example I gave. The fundamentals are the same.

You can also tune performance by setting XLA flags via the XLA_FLAGS environment variable. Here are some flags recommended for JAX, which should translate well to PyTorch/XLA: https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/PGLE.md#recommended-xla-flags
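Because XLA_FLAGS is a single space-separated string, it is easy to clobber flags that are already set. Below is a minimal Python sketch for merging flags, assuming the variable is set before the XLA runtime initializes (doing it before importing torch_xla is the safe choice). `set_xla_flags` is a hypothetical helper; the two example flag names are taken from the list posted later in this thread, so check them against your XLA version.

```python
import os


def set_xla_flags(*flags, env=None):
    """Merge `--name=value` flags into XLA_FLAGS; a new value replaces an old one."""
    env = os.environ if env is None else env
    merged = {}
    for tok in env.get("XLA_FLAGS", "").split():
        merged[tok.split("=", 1)[0]] = tok
    for flag in flags:
        # Overwriting an existing key keeps that flag's original position.
        merged[flag.split("=", 1)[0]] = flag
    env["XLA_FLAGS"] = " ".join(merged.values())
    return env["XLA_FLAGS"]


# Safest to call this before `import torch_xla`.
set_xla_flags(
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_async_all_gather=true",
)
```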

cc @vanbasten23 in case we have a set of PyTorch-specific flags.

mars1248 commented 3 months ago

@jonb377 Thank you for your answer. I found that the performance issue was caused by inter-machine tensor-parallel (TP) communication. I really only want intra-machine TP, like DeepSpeed does, but I would also love a way to do inter-machine FSDP. Regarding "Then, any batch-dimension sharding should operate over the combined ('replica', 'fsdp') axes": why is this step needed? Looking at the CUDA code in https://github.com/huggingface/transformers/compare/main...pytorch-tpu:transformers:llama2-google-next-training, it seems there is no need to shard the input. And why do you need to shard the activations? Is there any documentation I can study?

jonb377 commented 3 months ago

With FSDP, each device should process distinct data, so the global batch dimension must be sharded across all devices, both for inputs and activations. Since the devices are grouped along both the replica and fsdp axes, these axes must be recombined using the ('replica', 'fsdp') tuple in the partition spec when sharding the batch dimension.
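A small pure-Python sketch illustrates this. It maps each device of the (2, 8) mesh to the global batch rows it holds, assuming a row-major mesh and that the first axis in a tuple is the major one (as in JAX); `batch_rows` is a hypothetical helper. Sharding the batch only over 'fsdp' gives devices 0 and 8 the same rows, so the two machines would duplicate work.

```python
from math import prod


def batch_rows(global_batch, mesh_shape, batch_axes, axis_names=('replica', 'fsdp')):
    """Map device id -> (start, stop) rows of the global batch held by that device."""
    sizes = dict(zip(axis_names, mesh_shape))
    ways = prod(sizes[a] for a in batch_axes)
    chunk = global_batch // ways
    rows, cols = mesh_shape
    out = {}
    for r in range(rows):
        for c in range(cols):
            coords = {axis_names[0]: r, axis_names[1]: c}
            idx = 0
            for a in batch_axes:  # first axis in the tuple is the most significant
                idx = idx * sizes[a] + coords[a]
            out[r * cols + c] = (idx * chunk, (idx + 1) * chunk)
    return out


fsdp_only = batch_rows(64, (2, 8), ('fsdp',))
combined = batch_rows(64, (2, 8), ('replica', 'fsdp'))
print(fsdp_only[0], fsdp_only[8])  # (0, 8) (0, 8): both machines see the same rows
print(combined[0], combined[8])    # (0, 4) (32, 36): every device gets distinct rows
```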

The fork I shared is for SPMD on TPUs, so the relevant part is the usage of the SPMD APIs. There aren't any explicit changes to the CUDA codepaths AFAIK, cc @vanbasten23 in case I'm wrong.

Here is a blog post discussing the usage of SPMD on TPUs: https://pytorch.org/blog/high-performance-llama-2/

If your goal is FSDP across machines (i.e., sharding each parameter across all global devices), you'll definitely want to tune XLA flags for a more performant AllGather. The set I shared here is a good reference.

mars1248 commented 3 months ago

@jonb377 Thank you for your answer, but I still have some questions:

# Example input_sharding, same mesh from above, assuming 2D input tensor:
input_sharding = xs.ShardingSpec(mesh, (('replica', 'fsdp'), None))

# Example activation sharding:
query_states = self.q_proj(hidden_states)
xs.mark_sharding(query_states, mesh, (('replica', 'fsdp'), None, None))

Why do we need to shard the activation in the code above? What is the benefit? My understanding is that once the input and weights are sharded, the activations should automatically be sharded as well. Is that not the case?

jonb377 commented 3 months ago

Ah, that's a great question. To quote the JAX docs: "While the compiler will attempt to decide how a function’s intermediate values and outputs should be sharded, we can also give it hints. It’s often a good practice to annotate the outputs of computations, for example based on how the values are ultimately consumed."

In summary, the compiler will try to determine the correct activation sharding based on the parameter and input shardings, but there are cases where adding explicit intermediate shardings will help. The batch-dimension sharding I recommended could be a case the compiler can already handle well, so leaving it out may not cause any issues.

@alanwaketan or @yeounoh feel free to add to the discussion.

mars1248 commented 3 months ago

@jonb377 Hello, I wrote this standalone test to simulate FSDP logic in XLA. Why is the all-gather performed on the input tensor instead of the model weights? I also found that all-gather and compute do not overlap when FSDP runs. [Image: profile screenshot]

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.amp import autocast, GradScaler
import numpy as np
import torch.optim as optim
import torch_xla.debug.profiler as xp
import time
import os
# Setup profiler env var
os.environ['XLA_HLO_DEBUG'] = '1'
server = xp.start_server(9012)
xr.use_spmd()

num_devices = xr.global_runtime_device_count()

mesh_shape = (num_devices // 1, 1)

device_ids = np.array(range(num_devices))
# Axis names are optional in Mesh, but the partition specs below refer to 'fsdp' and 'replica'
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'replica'))

t1 = torch.randn(1600, 12800, device='cpu')

xt1 = t1.to(xm.xla_device())
xs.mark_sharding(xt1, mesh, (('fsdp', 'replica'), ) + (None,))

model = torch.nn.Linear(12800, 9600)
model.to(xm.xla_device())
model2 = torch.nn.Linear(9600, 1280)
model2.to(xm.xla_device())
partition_spec = ('fsdp',) + (None,)
xs.mark_sharding(list(model.parameters())[0], mesh, partition_spec)
xs.mark_sharding(list(model2.parameters())[0], mesh, partition_spec)

for step in range(10000000):
    if step == 10:
        xp.trace_detached('localhost:9012', "/tmp/tb")
    with xp.StepTrace('train_loop', step_num=step):
        output = model(xt1)
        output2 = model2(output)
        if step == 20:
            break
jonb377 commented 3 months ago

It comes down to compiler decisions. In this case, it propagates sharding to the non-batch axis, then decides it's more efficient to allgather the activations instead of weights. You can check the compiler's sharding propagation decisions by setting XLA_FLAGS='--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=spmd|propagation' in the environment.

I'm able to reproduce on TPU, and I see that the compiler is choosing to shard the non-batch dimension on the activation after sharding propagation:

ENTRY SyncTensorsGraph.21 {
  p2.4 = f32[1600,12800]{1,0} parameter(2), sharding={devices=[4,1]0,1,2,3}, metadata={...}
  p1.2 = f32[9600,12800]{1,0} parameter(1), sharding={devices=[4,1]0,1,2,3}, metadata={...}
  dot.5 = f32[1600,9600]{1,0} dot(p2.4, p1.2), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[1,4]0,1,2,3}, metadata={...}
...
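A rough byte count is consistent with the compiler's choice here. The sketch below counts only the bytes each of the 4 devices receives in an all-gather (f32, shapes taken from the HLO above) and ignores everything else XLA's cost model considers; `allgather_bytes_per_device` is a hypothetical helper.

```python
def allgather_bytes_per_device(shape, n_shards, bytes_per_elem=4):
    """Bytes each device receives when all-gathering a tensor split n_shards ways."""
    numel = 1
    for d in shape:
        numel *= d
    # Each device already holds 1/n of the tensor and receives the rest.
    return numel * bytes_per_elem * (n_shards - 1) // n_shards


act = allgather_bytes_per_device((1600, 12800), 4)  # gather the input activation
wgt = allgather_bytes_per_device((9600, 12800), 4)  # gather the weight instead
print(act, wgt, wgt // act)  # 61440000 368640000 6
```

Gathering the input moves about 6x fewer bytes than gathering the weight, which fits the propagated `[1,4]` (non-batch) sharding on `dot.5`.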

If you add more sharding hints to the program (e.g., xs.mark_sharding(output, mesh, (('fsdp', 'replica'), None))), propagation will respect them and maintain batch-dimension sharding.
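The sharding annotations in such a dump can be decoded by hand. Below is a pure-Python sketch that turns a simple `devices=[...]` annotation into a per-device shape; it ignores other forms such as partial replication, and `tile_shape` is a hypothetical helper. It contrasts the propagated `[1,4]` sharding of `dot.5` with a batch-dimension `[4,1]` layout.

```python
import re


def tile_shape(global_shape, sharding):
    """Per-device shape for an HLO sharding such as '{devices=[1,4]0,1,2,3}'."""
    match = re.search(r"devices=\[([\d,]+)\]", sharding)
    if match is None:
        raise ValueError(f"unsupported sharding: {sharding}")
    tiles = [int(t) for t in match.group(1).split(",")]
    if len(tiles) != len(global_shape):
        raise ValueError("only the simple tiled form is handled")
    # Round up, so uneven splits give the padded tile size.
    return tuple(-(-dim // t) for dim, t in zip(global_shape, tiles))


print(tile_shape((1600, 9600), "{devices=[1,4]0,1,2,3}"))  # (1600, 2400): non-batch split
print(tile_shape((1600, 9600), "{devices=[4,1]0,1,2,3}"))  # (400, 9600): batch split
```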

Regarding synchronous allgathers, activations necessarily depend on the result of the computation, so the poor compiler choice to gather activations instead of weights will block the gathers until the compute completes. Could you share the set of XLA flags you're using, since that can impact compiler choices?

I would also recommend checking out @alanwaketan's FSDPv2 wrapper, since it abstracts much of the complexity and automatically applies activation sharding annotations. You'll need a custom shard_output to shard across the combined ('fsdp', 'replica') axes, since by default the activation sharding applies only to the fsdp axis.

mars1248 commented 3 months ago

@jonb377 I used these flags

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_gpu_enable_async_all_gather=true \
--xla_gpu_enable_async_reduce_scatter=true \
--xla_gpu_enable_triton_gemm=false \
--xla_gpu_simplify_all_fp_conversions \
--xla_gpu_graph_level=0 \
--xla_gpu_enable_async_all_reduce=true \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 \
--xla_gpu_all_gather_combine_threshold_bytes=1073741824 \
--xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824 \
--xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true \
--xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_enable_all_gather_combine_by_dim=false \
--xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization"
mars1248 commented 3 months ago

@jonb377 As you suggested, I have successfully got the full FSDP flow working on GPU. However, I found that computation and communication run on the same stream, so there is no overlap. Is there any way to solve this? [Image: profile screenshot]

JackCaoG commented 3 months ago

I remember @Liyang90 told me that GPU SPMD puts compute and communication in the same stream.

mars1248 commented 3 months ago

@JackCaoG @jonb377 @Liyang90 If they're all on the same stream, how can computation and communication overlap? Does this need to be improved? I would also like to ask how to separate computation and communication into two streams in the non-SPMD training scenario. Which part of the code should I refer to?

JackCaoG commented 3 months ago

I chatted with @vanbasten23 and @jonb377 offline. I think there are flags that will enable async collectives, but we might need to try them out first.

huzama commented 4 weeks ago

I encountered a similar issue a while back. My code achieves 70% TPU utilization on a single TPU machine. However, when I scale up from one to 32 machines, the TPU utilization drops to 30%. The main contributors to the wasted time in v4-256 are all-gather-done, all-reduce-scatter-fusion, and collective-permute-done.

For your reference, I am using 2D sharding with SPMD.

I am not an expert in async collectives, but I remember reading somewhere that they can improve performance significantly.

JackCaoG commented 4 weeks ago

Can you share the profile (xplane file)? It looks like your model is communication bound, which might be because the per-device batch size is small and there isn't enough compute to overlap with the communication.
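To make "not enough compute to overlap with the communication" concrete, here is a back-of-the-envelope sketch for a single FSDP-sharded linear layer: forward matmul FLOPs per byte of weight that must be all-gathered. It assumes bf16 weights gathered once per forward pass and ignores the backward pass, reduce-scatter, and attention; `flops_per_gathered_byte` is a hypothetical helper and the shapes are illustrative.

```python
def flops_per_gathered_byte(tokens_per_device, d_in, d_out, n_shards, bytes_per_elem=2):
    """Forward matmul FLOPs per byte all-gathered for one FSDP-sharded linear layer."""
    flops = 2 * tokens_per_device * d_in * d_out  # one multiply-add counts as 2 FLOPs
    gathered = d_in * d_out * bytes_per_elem * (n_shards - 1) / n_shards
    return flops / gathered


small = flops_per_gathered_byte(tokens_per_device=1 * 2048, d_in=4096, d_out=4096, n_shards=128)
large = flops_per_gathered_byte(tokens_per_device=4 * 2048, d_in=4096, d_out=4096, n_shards=128)
print(large / small)  # 4.0: 4x the tokens per device gives 4x the compute per gathered byte
```

If this ratio falls below what the hardware sustains (peak FLOP/s divided by collective bandwidth), the all-gather cannot be fully hidden behind the matmul.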

huzama commented 4 weeks ago

@JackCaoG I conducted several experiments by varying the batch size. The attached file contains profile for a batch size of 256. Interestingly, even when I increased the batch size to 1024, the utilization rate remained unchanged. localhost_9012.xplane.pb.zip

JackCaoG commented 3 weeks ago

@huzama I looked at your profile, and most of the conv fusions have the output shape

bf16[8,767,4096]

and the all-gathers look like

%all-gather.813 = bf16[8,767,4096]{2,1,0:T(8,128)(2,1)S(3)} all-gather(bf16[8,767,1024]

It seems these tensors are only 4-way sharded and replicated along the other dimension. What does your device mesh look like? You said you scaled up to 32 machines, which suggests you have 128 devices; is your mesh (32, 4)?
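The shape reasoning above can be sketched in a few lines: the all-gather grows only the last dimension from 1024 to 4096, i.e. a 4-way split, and with 128 devices the remaining factor of 32 must come from the other mesh axis. `gather_factor` is a hypothetical helper; the 4 chips per machine follow from the 32 machines / 128 devices figure.

```python
def gather_factor(in_shape, out_shape):
    """Return (dim, ways) for an all-gather that grows exactly one dimension."""
    grown = [(i, o // s) for i, (s, o) in enumerate(zip(in_shape, out_shape)) if s != o]
    if len(grown) != 1:
        raise ValueError("expected exactly one gathered dimension")
    return grown[0]


dim, ways = gather_factor((8, 767, 1024), (8, 767, 4096))
num_devices = 32 * 4  # 32 machines x 4 TPU chips each
other_axis = num_devices // ways
print(dim, ways, other_axis)  # 2 4 32 -> consistent with a (32, 4) mesh
```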

Also, [8,767,4096] is fairly small. At a similar scale, our matmuls/convs are usually much larger.

Lastly, do you have to use 2D sharding? You are likely to get better performance using the FSDPv2 wrapper we provide. Check https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py#L54-L55. We only recommend 2D sharding if FSDPv2 runs out of memory even with a per-device batch size of 1. For Llama 2, we can easily achieve 50%+ MFU at this scale with FSDPv2.

huzama commented 3 weeks ago

@JackCaoG, you are right. My mesh shape for this profile file is (4 - model, 32 - data). I have also tried different configurations such as (2, 64) and (8, 16), but they did not make a difference.

I am using 2D SPMD and following all the steps from this blog post. Additionally, I have used a global batch size of up to 512, resulting in a per-chip batch size of 16 with (4, 32) sharding. Changing the batch size did not make a difference either.

All I could deduce from these experiments is that the bottleneck lies somewhere else.

I also tried FSDPv2, which reached 40% MFU. That is better, but still short of the 50%+ MFU you mentioned.
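When comparing figures like 30%, 40%, and 50%+ MFU, it helps to pin down the formula. The sketch below uses the common 6 * N FLOPs-per-token approximation for training a dense decoder (forward plus backward, attention-score FLOPs ignored). The thread does not say how its MFU figures were computed, and the peak FLOP/s per chip is an input you would take from the hardware spec; `mfu` is a hypothetical helper.

```python
def mfu(n_params, tokens_per_sec, n_chips, peak_flops_per_chip):
    """Model FLOPs utilization: achieved model FLOP/s over aggregate peak FLOP/s.

    Uses the 6 * N training FLOPs-per-token approximation for a dense model,
    which ignores attention-score FLOPs, so it slightly underestimates the work.
    """
    achieved_flops_per_sec = 6 * n_params * tokens_per_sec
    return achieved_flops_per_sec / (n_chips * peak_flops_per_chip)
```

Because the numerator counts model FLOPs rather than hardware FLOPs, recomputation (for example from gradient checkpointing) does not raise MFU even though it keeps the chips busy.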

Additionally, for 2D SPMD, I have marked activations for sharding and replaced the linear layer with einsum attention. I am using the LLAMA model and have tried different model sizes, but nothing changes that 30% MFU.

Any insights or suggestions would be greatly appreciated!

JackCaoG commented 3 weeks ago

I think your model eventually becomes communication bound because the compute is too small. Are you running Llama 2 7B? The hidden size and intermediate size seem small. In FSDPv2 we need to overlap the all_gather of the parameters with the compute from the previous step. I think your options are either to keep increasing the batch size (which might not be ideal if you are concerned about training quality) or to use larger hidden and intermediate dimensions for the linear layers.
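The effect of hiding parameter all-gathers behind compute can be illustrated with a toy timing model. It assumes each layer's compute must wait for that layer's gather, that gathers run one at a time, and that the gather for layer i+1 can run while layer i computes; real XLA scheduling is more complex, and `step_time` is a hypothetical helper.

```python
def step_time(compute, gather, overlap=True):
    """Toy per-step time for FSDP layers; layer i needs gather[i] before compute[i].

    Without overlap, gathers and compute run back to back. With overlap, the
    gather for layer i + 1 runs during layer i's compute, so each layer costs
    max(compute, next gather).
    """
    if not overlap:
        return sum(compute) + sum(gather)
    total = gather[0]  # the first gather cannot be hidden
    for i, c in enumerate(compute):
        next_gather = gather[i + 1] if i + 1 < len(gather) else 0
        total += max(c, next_gather)
    return total


print(step_time([2, 2, 2], [1, 1, 1]))  # 7: compute-bound, later gathers hidden
print(step_time([1, 1, 1], [3, 3, 3]))  # 10: communication-bound
```

Once gathers take longer than compute, the overlapped step time tracks the total gather time (9 of the 10 units in the second case), which is why more compute per layer, via larger batches or wider layers, is the lever.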

JackCaoG commented 3 weeks ago

You can also use flash attention with FSDP by following my example in https://github.com/pytorch/xla/blob/master/examples/flash_attention/train_decoder_only_flash_attention_fsdp_v2.py, but again, your model is communication bound, so I doubt it will make your end-to-end time much faster. You should also open a new issue, since the original topic is about GPU.

huzama commented 3 weeks ago

Thank you @JackCaoG, I'll check it out and create a new issue.