huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

feat: support tensor parallel using Pytorch 2.0 & Data loader #3173

Open kmehant opened 1 month ago

kmehant commented 1 month ago

What does this PR do?

  1. Implements TorchTensorParallelPlugin to support TP with PyTorch 2.0. This work should be reviewed together with https://github.com/huggingface/transformers/pull/34194.
  2. Related: https://github.com/huggingface/transformers/pull/34184, which carries a different interface for applying the TP plan to the model.
  3. Modifies the dataloader so that all TP ranks receive the same samples.
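Items 1 and 3 are linked: under tensor parallelism each rank holds only a slice of every weight matrix but must see the full input. The sketch below assumes Megatron-style column-then-row sharding of a two-layer MLP, which is the pattern PyTorch's `ColwiseParallel`/`RowwiseParallel` styles express. It simulates that sharding with plain Python lists and is not the PR's implementation. All function names are hypothetical. Summing the per-rank partial outputs stands in for the all-reduce that a real TP layer performs.

```python
# Conceptual sketch: tensor parallelism for a two-layer MLP, simulated with
# plain Python lists (no GPUs, no torch.distributed). Each simulated rank owns
# a column shard of W1 and the matching row shard of W2.

def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def relu(m):
    return [[max(0.0, v) for v in row] for row in m]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def column_shard(w, rank, world_size):
    """Columns owned by `rank` (like a column-wise parallel linear layer)."""
    cols = len(w[0]) // world_size
    return [row[rank * cols:(rank + 1) * cols] for row in w]

def row_shard(w, rank, world_size):
    """Rows owned by `rank` (like a row-wise parallel linear layer)."""
    rows = len(w) // world_size
    return w[rank * rows:(rank + 1) * rows]

def mlp(x, w1, w2):
    """Reference: the unsharded MLP on a single device."""
    return matmul(relu(matmul(x, w1)), w2)

def tp_mlp(x, w1, w2, world_size):
    partials = []
    for rank in range(world_size):
        # Every rank consumes the *same* input x; this is why the dataloader
        # must hand identical samples to all ranks in a TP group.
        h = relu(matmul(x, column_shard(w1, rank, world_size)))
        partials.append(matmul(h, row_shard(w2, rank, world_size)))
    out = partials[0]
    for p in partials[1:]:
        out = add(out, p)  # stands in for the all-reduce (sum) across ranks
    return out

x = [[1.0, -2.0], [0.5, 3.0]]                            # 2 samples, hidden size 2
w1 = [[1.0, -1.0, 2.0, 0.5], [0.5, 1.5, -1.0, 2.0]]      # 2 x 4
w2 = [[1.0, 0.0], [2.0, 1.0], [-1.0, 1.0], [0.5, -0.5]]  # 4 x 2
print(mlp(x, w1, w2) == tp_mlp(x, w1, w2, world_size=2))  # prints True
```

Because ReLU is elementwise, the column-sharded activations never need to be gathered, so only one reduction is needed per MLP. If two ranks were fed different samples, their partial outputs would belong to different inputs and the sum would be meaningless.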

Please review in conjunction with https://github.com/huggingface/transformers/pull/34194

Results

We see significant improvements in both memory and throughput compared with single-GPU training and with FSDP, across different settings (gradient checkpointing on/off) and context lengths.

Benchmarks were run on two models:

  1. ibm-granite/granite-8b-code-base-128k
  2. codellama/CodeLlama-7b-hf

The results below show the maximum CUDA memory and per-GPU throughput for various configurations, illustrating the potential of the TP support contributed in this PR. There are gains in both memory and throughput.

Note: The effective throughput (tokens/sec) of FSDP scales with the parallel factor (the number of GPUs/devices engaged in distributed training), because each GPU processes different data. This is not the case with TP, where all ranks in the TP group process the same batch. Therefore, in terms of effective throughput, FSDP is better than TP. However, this may be compensated for by, for example, using the memory savings to increase the batch size.

| Model | Method | # of GPUs | Context Length | Batch Size | Grad Checkpointing | Max CUDA Mem (GiB) | Tokens/Sec/GPU |
|---|---|---|---|---|---|---|---|
| ibm-granite/granite-8b-code-base-128k | Single GPU (non-distributed) | 1 | 8192 | 1 | Off | OOM | N/A |
| ibm-granite/granite-8b-code-base-128k | FSDP | 4 | 8192 | 1 | Off | OOM | N/A |
| ibm-granite/granite-8b-code-base-128k | TP (This PR) | 4 | 8192 | 1 | Off | 52.4 | 7675.4 |
| ibm-granite/granite-8b-code-base-128k | Single GPU (non-distributed) | 1 | 8192 | 1 | On | OOM | N/A |
| ibm-granite/granite-8b-code-base-128k | FSDP | 4 | 8192 | 1 | On | 29.975586 | 2256.896 |
| ibm-granite/granite-8b-code-base-128k | TP (This PR) | 4 | 8192 | 1 | On | 26.5 | 5935.5 |
| ibm-granite/granite-8b-code-base-128k | Single GPU (non-distributed) | 1 | 16384 | 1 | Off | OOM | N/A |
| ibm-granite/granite-8b-code-base-128k | FSDP | 4 | 16384 | 1 | Off | OOM | N/A |
| ibm-granite/granite-8b-code-base-128k | TP (This PR) | 4 | 16384 | 1 | Off | OOM | N/A |
| ibm-granite/granite-8b-code-base-128k | Single GPU (non-distributed) | 1 | 16384 | 1 | On | OOM | N/A |
| ibm-granite/granite-8b-code-base-128k | FSDP | 4 | 16384 | 1 | On | 36.8 | 2084.864 |
| ibm-granite/granite-8b-code-base-128k | TP (This PR) | 4 | 16384 | 1 | On | 33.5 | 5692.5 |
| codellama/CodeLlama-7b-hf | Single GPU (non-distributed) | 1 | 8192 | 1 | Off | OOM | N/A |
| codellama/CodeLlama-7b-hf | FSDP | 4 | 8192 | 1 | Off | 70.7 | 3560 |
| codellama/CodeLlama-7b-hf | TP (This PR) | 4 | 8192 | 1 | Off | 42.8 | 9216 |
| codellama/CodeLlama-7b-hf | Single GPU (non-distributed) | 1 | 8192 | 1 | On | 75.3 | 2849 |
| codellama/CodeLlama-7b-hf | FSDP | 4 | 8192 | 1 | On | 26.4 | 5957 |
| codellama/CodeLlama-7b-hf | TP (This PR) | 4 | 8192 | 1 | On | 21.4 | 7125 |
| codellama/CodeLlama-7b-hf | Single GPU (non-distributed) | 1 | 16384 | 1 | Off | OOM | N/A |
| codellama/CodeLlama-7b-hf | FSDP | 4 | 16384 | 1 | Off | OOM | N/A |
| codellama/CodeLlama-7b-hf | TP (This PR) | 4 | 16384 | 1 | Off | OOM | N/A |
| codellama/CodeLlama-7b-hf | Single GPU (non-distributed) | 1 | 16384 | 1 | On | 75.3 | 2599 |
| codellama/CodeLlama-7b-hf | FSDP | 4 | 16384 | 1 | On | 30.1 | 2433 |
| codellama/CodeLlama-7b-hf | TP (This PR) | 4 | 16384 | 1 | On | 26.6 | 6873 |
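The note about effective throughput can be checked against these numbers. The following minimal sketch follows the note's interpretation, which is an assumption here. For FSDP, the per-GPU figure is multiplied by the 4 GPUs, since each GPU processes its own batch. For TP, the reported per-GPU figure is used as-is, since the 4 TP ranks share one batch. Only the configurations in which neither FSDP nor TP ran out of memory are included.

```python
# Effective throughput under the note's interpretation (an assumption):
#   FSDP: 4 GPUs each process their own batch -> per-GPU tokens/sec x 4
#   TP:   4 GPUs share one batch               -> per-GPU tokens/sec as reported
NUM_GPUS = 4

# (model, context length, grad checkpointing, FSDP tok/s/GPU, TP tok/s/GPU),
# copied from the benchmark rows where neither FSDP nor TP hit OOM.
ROWS = [
    ("granite-8b-code-base-128k", 8192, True, 2256.896, 5935.5),
    ("granite-8b-code-base-128k", 16384, True, 2084.864, 5692.5),
    ("CodeLlama-7b-hf", 8192, False, 3560.0, 9216.0),
    ("CodeLlama-7b-hf", 8192, True, 5957.0, 7125.0),
    ("CodeLlama-7b-hf", 16384, True, 2433.0, 6873.0),
]


def effective_tokens_per_sec(per_gpu, data_parallel_replicas):
    """Tokens/sec processed by the whole job."""
    return per_gpu * data_parallel_replicas


comparisons = []
for model, ctx, ckpt, fsdp, tp in ROWS:
    fsdp_eff = effective_tokens_per_sec(fsdp, NUM_GPUS)  # 4 independent batches
    tp_eff = effective_tokens_per_sec(tp, 1)             # one shared batch
    comparisons.append((model, ctx, ckpt, fsdp_eff, tp_eff))
    print(f"{model} ctx={ctx} ckpt={ckpt}: FSDP {fsdp_eff:.0f} vs TP {tp_eff:.0f} tokens/sec")
```

Under this reading FSDP comes out ahead in every row, matching the note. The margin is smallest for the 16384-token runs.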

Fixes https://github.com/huggingface/transformers/issues/32470

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

I have the bandwidth to make further improvements on top of this PR to bring PyTorch TP support to HF. Looking forward to your feedback. Thank you!

HuggingFaceDocBuilderDev commented 3 weeks ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

muellerzr commented 3 weeks ago

@kmehant if you rebase from main, this should fix the failures (tl;dr: Python 3.8 reached EOL).

kmehant commented 3 weeks ago

@muellerzr Thanks for your response. I would like to bring the following two points to your attention.

  1. This dataloader is written for the paradigm (call it paradigm 1) in which the main process fetches the data and distributes it to all worker processes. The more general paradigm (call it paradigm 2), in which every process fetches its own data samples (and, in the TP case, must end up with the same batch as the other processes), is not covered in this PR.
  2. This PR has a soft dependency for applying the TP plan to the model, since it really consists of two parts: the TP workflow through the accelerate plugin, and the dataloader.
    1. The first part of the PR applies TP to the model, as shown here (https://github.com/huggingface/accelerate/pull/3173/files#diff-2d7515874eaecac2687c7fc1a9c720be53f802bf14b4c3dcebe14ad443d075dcR1467), which creates a soft dependency on https://github.com/huggingface/transformers/pull/34194 (part of this would be superseded by https://github.com/huggingface/transformers/pull/34184, which carries a different interface for applying the TP plan to the model).
    2. The second part is the dataloader.
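The two paradigms in point (1) can be contrasted with a small simulation. This is a conceptual sketch in plain Python, not accelerate's dataloader, and `make_loader`, `paradigm_1` and `paradigm_2` are hypothetical names. In paradigm 1, only the main process reads data, and every TP rank receives a copy of each batch; the copy stands in for a broadcast. In paradigm 2, each rank reads on its own. All ranks must therefore draw the same sample order, here via the same shuffle seed, rather than the rank-specific shards used in data parallelism.

```python
import random

# Hypothetical helpers for illustration only; not accelerate's API.

def make_loader(dataset, batch_size, seed):
    """Yield shuffled batches; the order depends only on `seed`."""
    order = list(range(len(dataset)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [dataset[i] for i in order[start:start + batch_size]]

def paradigm_1(dataset, batch_size, tp_size, seed=0):
    """Main process fetches each batch once and sends a copy to every TP rank."""
    per_rank = [[] for _ in range(tp_size)]
    for batch in make_loader(dataset, batch_size, seed):  # only the main process reads
        for rank in range(tp_size):
            per_rank[rank].append(list(batch))            # stands in for a broadcast
    return per_rank

def paradigm_2(dataset, batch_size, tp_size, seed=0):
    """Every rank fetches its own data; a shared seed keeps the batches identical."""
    return [list(make_loader(dataset, batch_size, seed)) for _rank in range(tp_size)]

data = list(range(10))
p1 = paradigm_1(data, batch_size=4, tp_size=2)
p2 = paradigm_2(data, batch_size=4, tp_size=2)
print(p1[0] == p1[1], p2[0] == p2[1], p1 == p2)
```

Paradigm 1 reads the data only once but routes every batch through the main process. Paradigm 2 avoids that bottleneck but depends on every rank agreeing on sampler state.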

For point (1), I can keep this PR simple by supporting only paradigm 1 and address paradigm 2 in another PR. For point (2), I can remove the part that applies TP from this PR, keeping it simple and independent. The removed part can be added in a separate PR once point (2)(i) is completed.

WDYT?

kmehant commented 2 weeks ago

@muellerzr can I work on this https://github.com/huggingface/accelerate/pull/3173#pullrequestreview-2401793359 in a separate PR?

I have fetched and rebased my PR and addressed all the review comments. Thank you!

HoangCongDuc commented 5 days ago

This feature is really useful, thank you @kmehant. I wonder whether it will be possible to combine tensor parallelism with data parallelism after this PR, say, TP within a node and DP across nodes.

kmehant commented 5 days ago

Hi @HoangCongDuc, support for that is on my TODO list but is not covered in this PR. It should come soon, after discussing it with HF. Thank you.
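The TP + DP combination asked about above (TP within a node, DP across nodes) is commonly laid out as a 2D (dp, tp) grid of ranks. The following is a minimal sketch in plain Python. It assumes ranks are numbered node by node, and `build_groups` and `data_shard_for` are hypothetical helpers, not part of this PR. In recent PyTorch releases, `torch.distributed.device_mesh.init_device_mesh` with `mesh_dim_names=("dp", "tp")` builds a comparable grid of process groups.

```python
# Sketch: place global ranks on a (dp, tp) grid so that each TP group stays
# within one node and each DP group spans nodes. Assumes ranks are numbered
# node by node (ranks 0..gpus_per_node-1 on node 0, and so on).

def build_groups(world_size, tp_size):
    if world_size % tp_size:
        raise ValueError("world_size must be divisible by tp_size")
    dp_size = world_size // tp_size
    # Row d of the grid is one TP group: tp_size consecutive ranks (one node).
    tp_groups = [list(range(d * tp_size, (d + 1) * tp_size)) for d in range(dp_size)]
    # Column t is one DP group: the rank with TP index t on every node.
    dp_groups = [list(range(t, world_size, tp_size)) for t in range(tp_size)]
    return tp_groups, dp_groups

def data_shard_for(rank, tp_size):
    """Ranks in the same TP group share a data shard; the DP index selects it."""
    return rank // tp_size

tp_groups, dp_groups = build_groups(world_size=8, tp_size=4)  # 2 nodes x 4 GPUs
print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print([data_shard_for(r, 4) for r in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

The dataloader would then give identical batches within each TP group (as this PR does) and different shards across DP groups.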