Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

Extreme slow-down when using FSDP #607

Closed GrauleM closed 2 months ago

GrauleM commented 1 year ago

I was struggling to get LLaMA2 inference to run well on multiple GPUs. After playing around a bit, I observed an extreme slow-down when using the FSDP strategy with 2+ devices.

How to reproduce the problem:

python generate/base.py  --checkpoint_dir 'path_to_llama2/Llama-2-7b-hf' --devices 1 --strategy auto

-> Token generation rate: 33.42 tokens/sec

python generate/base.py  --checkpoint_dir 'path_to_llama2/Llama-2-7b-hf' --devices 2 --strategy fsdp

-> Token generation rate: 2.54 tokens/sec

python generate/base.py  --checkpoint_dir 'path_to_llama2/Llama-2-7b-hf' --devices 4 --strategy fsdp

-> Token generation rate: 3.34 tokens/sec

I observed similar trends for the other models.

My setup: 4x NVIDIA RTX A5000 (24 GB each); Ubuntu 22.04.2; conda environment with Python 3.10 and CUDA 11.8

Andrei-Aksionov commented 11 months ago

I can confirm the same behavior on my machine with pythia-2.8b.

With the auto strategy (DDPStrategy) the speed is 17.30 tokens/sec:

main ~/lit-gpt python generate/base.py --checkpoint_dir checkpoints/EleutherAI/pythia-2.8b --devices 4 --strategy auto
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

Loading model 'checkpoints/EleutherAI/pythia-2.8b/lit_model.pth' with {'name': 'pythia-2.8b', 'hf_config': {'org': 'EleutherAI', 'name': 'pythia-2.8b'}, 'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 128, 'padded_vocab_size': 50304, 'n_layer': 32, 'n_head': 32, 'n_embd': 2560, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 32, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 10240, 'rope_condense_ratio': 1, 'rope_base': 10000, 'head_size': 80, 'rope_n_elem': 20}
Time to instantiate model: 1.38 seconds.
[rank: 2] Seed set to 1234
[rank: 3] Seed set to 1234
Time to load the model weights: 4.73 seconds.
[rank: 0] Seed set to 1234
[rank: 1] Seed set to 1234
Hello, my name is Brian Patrick Lowenthal IV.</legend>\nest html:</legend>\nest html:</legend>"
          <!-- END COMMENT
          -->
          "</figure>\nest html:</figure>\nest html:</figure>",
          null
          });

Time for inference 1: 2.89 sec total, 17.30 tokens/sec
Memory used: 11.16 GB

... and with the fsdp strategy it's 0.77 tokens/sec:

main ~/lit-gpt python generate/base.py --checkpoint_dir checkpoints/EleutherAI/pythia-2.8b --devices 4 --strategy fsdp
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

Loading model 'checkpoints/EleutherAI/pythia-2.8b/lit_model.pth' with {'name': 'pythia-2.8b', 'hf_config': {'org': 'EleutherAI', 'name': 'pythia-2.8b'}, 'block_size': 2048, 'vocab_size': 50254, 'padding_multiple': 128, 'padded_vocab_size': 50304, 'n_layer': 32, 'n_head': 32, 'n_embd': 2560, 'rotary_percentage': 0.25, 'parallel_residual': True, 'bias': True, 'lm_head_bias': False, 'n_query_groups': 32, 'shared_attention_norm': False, '_norm_class': 'LayerNorm', 'norm_eps': 1e-05, '_mlp_class': 'GptNeoxMLP', 'gelu_approximate': 'none', 'intermediate_size': 10240, 'rope_condense_ratio': 1, 'rope_base': 10000, 'head_size': 80, 'rope_n_elem': 20}
Time to instantiate model: 0.04 seconds.
[rank3]:[2023-11-08 14:22:19,297] torch.distributed.fsdp._debug_utils: [WARNING] FSDP model load_state_dict profiling:  defaultdict(<class 'float'>, {'_enter_unshard_params_ctx': 0.023662725000576756, '_exit_unshard_params_ctx': 0.020300298000620387})
[rank2]:[2023-11-08 14:22:19,306] torch.distributed.fsdp._debug_utils: [WARNING] FSDP model load_state_dict profiling:  defaultdict(<class 'float'>, {'_enter_unshard_params_ctx': 0.02244255800087558, '_exit_unshard_params_ctx': 0.01876510300053269})
[rank1]:[2023-11-08 14:22:19,308] torch.distributed.fsdp._debug_utils: [WARNING] FSDP model load_state_dict profiling:  defaultdict(<class 'float'>, {'_enter_unshard_params_ctx': 0.021965552999972715, '_exit_unshard_params_ctx': 0.01884183099900838})
[rank0]:[2023-11-08 14:22:19,318] torch.distributed.fsdp._debug_utils: [WARNING] FSDP model load_state_dict profiling:  defaultdict(<class 'float'>, {'_enter_unshard_params_ctx': 0.020666511000399623, '_exit_unshard_params_ctx': 0.017720086999815976})
Time to load the model weights: 6.55 seconds.
[rank: 0] Seed set to 1234
[rank: 1] Seed set to 1234
[rank: 3] Seed set to 1234
[rank: 2] Seed set to 1234
Hello, my name is Mabel.

I’m a new student at the university.
I’m very friendly and nice.

I’m a fashion model.
My favorite color is blue.

I’m very kind.
I
Time for inference 1: 64.82 sec total, 0.77 tokens/sec
Memory used: 2.44 GB

I understand that there is communication overhead, but this is too dramatic a drop in token generation speed. @awaelchli, you have a lot of experience with multi-GPU training. Could you comment on this?

carmocca commented 11 months ago

A 7B model in half precision requires a minimum of 7 (billion) × 2 (bytes) = 14 GB of communication for every forward call with FSDP during inference; to sustain even one token per second, that is 14 (GB/s) × 8 (bits/byte) = 112 (Gb/s) of bandwidth.

This assumes that the model is completely sharded; in practice it would be somewhat less, because only the transformer blocks are sharded.

So depending on the intra-node connectivity, FSDP might not be feasible for models in the billions of parameters range.
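
As a rough back-of-the-envelope check of what that means for generation speed (my assumptions, for illustration only: fp16/bf16 weights at 2 bytes per parameter, the whole model re-gathered on every forward call, one forward call per generated token, and the ~4 GB/s PCIe 3.0 x4 link speed discussed below):

# Rough upper bound on FSDP generation speed imposed by parameter all-gathers alone.
# Assumptions (for illustration only): 2 bytes/param, whole model re-gathered
# every forward call, one forward call per generated token.
def comm_bound_tokens_per_sec(params_billion, link_gb_per_s, bytes_per_param=2):
    gb_per_forward = params_billion * bytes_per_param  # e.g. 7 * 2 = 14 GB
    return link_gb_per_s / gb_per_forward              # tokens/sec

print(comm_bound_tokens_per_sec(7.0, 4))   # Llama-2-7B over a ~4 GB/s link  -> ~0.29
print(comm_bound_tokens_per_sec(2.8, 4))   # pythia-2.8b over the same link  -> ~0.71

These are only order-of-magnitude bounds: in practice only the transformer blocks are sharded, and the actual link speed on your machine may differ.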

You can check the connectivity by running nvidia-smi -m topo

For instance, a result like

        GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity
GPU0     X      PHB     PHB     PHB     0-47            N/A
GPU1    PHB      X      PHB     PHB     0-47            N/A
GPU2    PHB     PHB      X      PHB     0-47            N/A
GPU3    PHB     PHB     PHB      X      0-47            N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

shows that the GPUs are connected over plain PCIe (no NVLink). Checking the PCIe version on my machine returns:

...
Type: x4 PCI Express 3 x4

which, according to https://www.kingston.com/en/blog/pc-performance/pcie-gen-4-explained, corresponds to about 4 GB/s, making communication a huge bottleneck.

So I suggest that you do a similar analysis on your setup. This is most likely the cause of the issue.
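
Besides the topology matrix, nvidia-smi can also report the current PCIe link generation and width per GPU, which is another way to do that check (these query fields should be available on recent drivers; nvidia-smi --help-query-gpu lists what your version supports):

nvidia-smi --query-gpu=name,pcie.link.gen.current,pcie.link.width.current --format=csv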

Andrei-Aksionov commented 11 months ago

Thanks @carmocca for such a detailed explanation!

Only one correction: the command is nvidia-smi topo -m.

I can confirm that on my machine with 4x A10G the connection is slow:

        GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity
GPU0     X      PHB     PHB     PHB     0-47            N/A
GPU1    PHB      X      PHB     PHB     0-47            N/A
GPU2    PHB     PHB      X      PHB     0-47            N/A
GPU3    PHB     PHB     PHB      X      0-47            N/A