meta-llama / llama-recipes

Scripts for fine-tuning Meta Llama with composable FSDP & PEFT methods to cover single/multi-node GPUs. Supports default & custom datasets for applications such as summarization and Q&A. Supports a number of inference solutions such as HF TGI and vLLM for local or cloud deployment. Demo apps to showcase Meta Llama for WhatsApp & Messenger.

Long context #785

Closed Amerehei closed 4 days ago

Amerehei commented 5 days ago

When I increase --context_length, training ends with an OutOfMemoryError no matter how many GPUs are available.

I'm wondering why fully sharded data parallelism (FSDP) doesn't help as expected.

torchrun --nnodes 1 --nproc_per_node 8  recipes/quickstart/finetuning/finetuning.py --context_length 50000 --enable_fsdp --model_name /path_of_model_folder/8B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels

How can I train a model with more than 50K context length?

wukaixingxp commented 5 days ago

@Amerehei May I know what GPU you are using and how many GPUs you have? Fine-tuning needs a lot of memory, and a 50K context needs an enormous amount of GPU memory; you can try this GPU memory tool to estimate it. BTW, our Llama 3.1 models already support 128K context length, so you do not need to fine-tune to get that.
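
To get a rough feel for the scale, here is a back-of-envelope sketch in Python (not the memory tool mentioned above; the Llama 3.1 8B-ish shapes are assumptions and the result is only a lower bound):

# Rough lower bound on activation memory alone at a given context length,
# assuming activation checkpointing (each transformer layer keeps its bf16
# input, i.e. seq_len * hidden * 2 bytes). Attention workspaces, MLP
# intermediates during recompute, gradients and allocator fragmentation all
# come on top, and FSDP does not shard activations, so every rank pays this in full.

def checkpointed_activations_gib(n_layers, hidden, seq_len, bytes_per_el=2):
    return n_layers * seq_len * hidden * bytes_per_el / 2**30

# Llama 3.1 8B-ish shapes (32 layers, hidden size 4096) at 50K context
print(f"{checkpointed_activations_gib(32, 4096, 50_000):.1f} GiB")  # ~12 GiB per rank, before any overhead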

wukaixingxp commented 5 days ago

What is your command? Can you send me the error log?

Amerehei commented 4 days ago

This is the output using Python 3.11.10 and PyTorch 2.5.1 (I tested 2.4 as well):

pip list
Package                                  Version
---------------------------------------- --------------
absl-py                                  2.1.0
accelerate                               1.1.1
aiohappyeyeballs                         2.4.3
aiohttp                                  3.10.10
aiosignal                                1.3.1
annotated-types                          0.7.0
antlr4-python3-runtime                   4.9.3
anyio                                    4.6.2.post1
appdirs                                  1.4.4
argon2-cffi                              23.1.0
argon2-cffi-bindings                     21.2.0
arrow                                    1.3.0
asttokens                                2.4.1
async-lru                                2.0.4
attrs                                    24.2.0
babel                                    2.16.0
backoff                                  2.2.1
beautifulsoup4                           4.12.3
bitsandbytes                             0.44.1
black                                    24.10.0
bleach                                   6.2.0
blinker                                  1.4
boltons                                  21.0.0
bracex                                   2.5.post1
Brotli                                   1.1.0
cachetools                               5.5.0
certifi                                  2024.8.30
cffi                                     1.17.1
chardet                                  5.2.0
charset-normalizer                       3.4.0
click                                    8.1.7
click-option-group                       0.5.6
codeshield                               1.0.1
colorama                                 0.4.6
coloredlogs                              15.0.1
comm                                     0.2.2
contourpy                                1.3.1
cryptography                             43.0.3
cycler                                   0.12.1
dataclasses-json                         0.6.7
datasets                                 3.1.0
dbus-python                              1.2.18
debugpy                                  1.8.8
decorator                                5.1.1
defusedxml                               0.7.1
Deprecated                               1.2.14
dill                                     0.3.8
distro                                   1.7.0
effdet                                   0.4.1
emoji                                    2.14.0
entrypoints                              0.4
eval_type_backport                       0.2.0
evaluate                                 0.4.3
exceptiongroup                           1.2.2
executing                                2.1.0
face                                     24.0.0
fastjsonschema                           2.20.0
filelock                                 3.13.1
filetype                                 1.2.0
fire                                     0.7.0
flatbuffers                              24.3.25
fonttools                                4.54.1
fqdn                                     1.5.1
frozenlist                               1.5.0
fsspec                                   2024.2.0
glom                                     22.1.0
google-api-core                          2.23.0
google-auth                              2.36.0
google-cloud-vision                      3.8.0
googleapis-common-protos                 1.66.0
grpcio                                   1.67.1
grpcio-status                            1.62.3
h11                                      0.14.0
httpcore                                 1.0.6
httplib2                                 0.20.2
httpx                                    0.27.2
huggingface-hub                          0.26.2
humanfriendly                            10.0
idna                                     3.10
importlib_metadata                       7.1.0
inflate64                                1.0.0
iopath                                   0.1.10
ipykernel                                6.29.5
ipython                                  8.29.0
ipython-genutils                         0.2.0
ipywidgets                               8.1.5
isoduration                              20.11.0
jedi                                     0.19.2
jeepney                                  0.7.1
Jinja2                                   3.1.3
joblib                                   1.4.2
json5                                    0.9.28
jsonpath-python                          1.0.6
jsonpointer                              3.0.0
jsonschema                               4.23.0
jsonschema-specifications                2024.10.1
jupyter-archive                          3.4.0
jupyter_client                           7.4.9
jupyter_contrib_core                     0.4.2
jupyter_contrib_nbextensions             0.7.0
jupyter_core                             5.7.2
jupyter-events                           0.10.0
jupyter-highlight-selected-word          0.2.0
jupyter-lsp                              2.2.5
jupyter_nbextensions_configurator        0.6.4
jupyter_server                           2.14.2
jupyter_server_terminals                 0.5.3
jupyterlab                               4.2.5
jupyterlab_pygments                      0.3.0
jupyterlab_server                        2.27.3
jupyterlab_widgets                       3.0.13
keyring                                  23.5.0
kiwisolver                               1.4.7
langdetect                               1.0.9
launchpadlib                             1.10.16
layoutparser                             0.3.4
lazr.restfulclient                       0.14.4
lazr.uri                                 1.0.6
llama-recipes                            0.0.4.post1
loralib                                  0.1.2
lxml                                     5.3.0
markdown-it-py                           3.0.0
MarkupSafe                               2.1.5
marshmallow                              3.23.1
matplotlib                               3.9.2
matplotlib-inline                        0.1.7
mdurl                                    0.1.2
mistune                                  3.0.2
more-itertools                           8.10.0
mpmath                                   1.3.0
multidict                                6.1.0
multiprocess                             0.70.16
multivolumefile                          0.2.3
mypy-extensions                          1.0.0
nbclassic                                1.1.0
nbclient                                 0.10.0
nbconvert                                7.16.4
nbformat                                 5.10.4
nest-asyncio                             1.6.0
networkx                                 3.2.1
nltk                                     3.9.1
notebook                                 6.5.5
notebook_shim                            0.2.4
numpy                                    1.26.3
nvidia-cublas-cu12                       12.4.5.8
nvidia-cuda-cupti-cu12                   12.4.127
nvidia-cuda-nvrtc-cu12                   12.4.127
nvidia-cuda-runtime-cu12                 12.4.127
nvidia-cudnn-cu12                        9.1.0.70
nvidia-cufft-cu12                        11.2.1.3
nvidia-curand-cu12                       10.3.5.147
nvidia-cusolver-cu12                     11.6.1.9
nvidia-cusparse-cu12                     12.3.1.170
nvidia-nccl-cu12                         2.21.5
nvidia-nvjitlink-cu12                    12.4.127
nvidia-nvtx-cu12                         12.4.127
oauthlib                                 3.2.0
omegaconf                                2.3.0
onnx                                     1.17.0
onnxruntime                              1.20.0
openai                                   1.38.0
opencv-python                            4.10.0.84
opentelemetry-api                        1.25.0
opentelemetry-exporter-otlp-proto-common 1.25.0
opentelemetry-exporter-otlp-proto-http   1.25.0
opentelemetry-instrumentation            0.46b0
opentelemetry-instrumentation-requests   0.46b0
opentelemetry-proto                      1.25.0
opentelemetry-sdk                        1.25.0
opentelemetry-semantic-conventions       0.46b0
opentelemetry-util-http                  0.46b0
optimum                                  1.23.3
overrides                                7.7.0
packaging                                24.2
pandas                                   2.2.3
pandocfilters                            1.5.1
parso                                    0.8.4
pathspec                                 0.12.1
pdf2image                                1.17.0
pdfminer.six                             20231228
pdfplumber                               0.11.4
peewee                                   3.17.8
peft                                     0.13.2
pexpect                                  4.9.0
pi_heif                                  0.20.0
pikepdf                                  9.4.1
pillow                                   10.2.0
pip                                      24.3.1
platformdirs                             4.3.6
portalocker                              2.10.1
prometheus_client                        0.21.0
prompt_toolkit                           3.0.48
propcache                                0.2.0
proto-plus                               1.25.0
protobuf                                 4.25.5
psutil                                   6.1.0
ptyprocess                               0.7.0
pure_eval                                0.2.3
py7zr                                    0.22.0
pyarrow                                  18.0.0
pyasn1                                   0.6.1
pyasn1_modules                           0.4.1
pybcj                                    1.0.2
pycocotools                              2.0.8
pycparser                                2.22
pycryptodomex                            3.21.0
pydantic                                 2.9.2
pydantic_core                            2.23.4
Pygments                                 2.18.0
PyGObject                                3.42.1
PyJWT                                    2.3.0
pyparsing                                2.4.7
pypdf                                    5.1.0
pypdfium2                                4.30.0
pyppmd                                   1.1.0
python-apt                               2.4.0+ubuntu4
python-dateutil                          2.8.2
python-iso639                            2024.10.22
python-json-logger                       2.0.7
python-magic                             0.4.27
python-multipart                         0.0.17
pytz                                     2024.2
PyYAML                                   6.0.1
pyzmq                                    24.0.1
pyzstd                                   0.16.2
RapidFuzz                                3.10.1
referencing                              0.35.1
regex                                    2024.11.6
requests                                 2.32.3
requests-toolbelt                        1.0.0
rfc3339-validator                        0.1.4
rfc3986-validator                        0.1.1
rich                                     13.5.3
rouge_score                              0.1.2
rpds-py                                  0.21.0
rsa                                      4.9
ruamel.yaml                              0.17.40
ruamel.yaml.clib                         0.2.12
safetensors                              0.4.5
scikit-learn                             1.5.2
scipy                                    1.14.1
SecretStorage                            3.3.1
semgrep                                  1.96.0
Send2Trash                               1.8.3
sentence-transformers                    3.3.0
sentencepiece                            0.2.0
setuptools                               75.4.0
six                                      1.16.0
sniffio                                  1.3.1
soupsieve                                2.6
stack-data                               0.6.3
sympy                                    1.13.1
tabulate                                 0.9.0
termcolor                                2.5.0
terminado                                0.18.1
texttable                                1.7.0
threadpoolctl                            3.5.0
timm                                     1.0.11
tinycss2                                 1.4.0
tokenize_rt                              6.1.0
tokenizers                               0.20.3
tomli                                    2.0.2
torch                                    2.5.1+cu124
torchaudio                               2.5.1+cu124
torchvision                              0.20.1+cu124
tornado                                  6.4.1
tqdm                                     4.67.0
traitlets                                5.14.3
transformers                             4.46.2
triton                                   3.1.0
types-python-dateutil                    2.9.0.20241003
typing_extensions                        4.8.0
typing-inspect                           0.9.0
tzdata                                   2024.2
unstructured                             0.15.8
unstructured-client                      0.27.0
unstructured-inference                   0.7.36
unstructured.pytesseract                 0.3.13
uri-template                             1.3.0
urllib3                                  2.2.3
wadllib                                  1.3.6
wcmatch                                  8.5.2
wcwidth                                  0.2.13
webcolors                                24.11.1
webencodings                             0.5.1
websocket-client                         1.8.0
wheel                                    0.45.0
widgetsnbextension                       4.0.13
wrapt                                    1.16.0
xxhash                                   3.5.0
yarl                                     1.17.1
zipp                                     1.0.0
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Sep_12_02:18:05_PDT_2024
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0
nvidia-smi
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.6     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A40                     On  | 00000000:53:00.0 Off |                    0 |
|  0%   40C    P0             208W / 300W |  12700MiB / 46068MiB |     94%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A40                     On  | 00000000:57:00.0 Off |                    0 |
|  0%   38C    P0             194W / 300W |  12700MiB / 46068MiB |     94%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A40                     On  | 00000000:98:00.0 Off |                    0 |
|  0%   41C    P0             196W / 300W |  12676MiB / 46068MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A40                     On  | 00000000:CE:00.0 Off |                    0 |
|  0%   38C    P0             193W / 300W |  12700MiB / 46068MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A40                     On  | 00000000:D1:00.0 Off |                    0 |
|  0%   41C    P0             205W / 300W |  12724MiB / 46068MiB |     99%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A40                     On  | 00000000:D2:00.0 Off |                    0 |
|  0%   39C    P0             212W / 300W |  12724MiB / 46068MiB |     99%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A40                     On  | 00000000:D6:00.0 Off |                    0 |
|  0%   41C    P0             213W / 300W |  12654MiB / 46068MiB |     99%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

I run the following command; here I use a very small dataset just for testing (I tried the default dataset as well).

Using a single GPU, I can increase context_length up to 22000, and it fails at 23000:

torchrun --nnodes 1 \
--nproc_per_node 1  \
finetuning.py \
--model_name meta-llama/Llama-3.2-1B \
--context_length 22000 \
--num_epochs 1 \
--dataset "custom_dataset" \
--custom_dataset.file "custom_dataset.py" \
--enable_fsdp \
--use_peft \
--peft_method lora \
--output_dir ./output \
--batch_size_training 1

When I increase nproc_per_node to 7, I can increase context_length to 24000, and it fails at 25000. Is this normal? Why can a single GPU handle 22000 tokens while adding 6 more GPUs only increases the context size by about 9%?

torchrun --nnodes 1 \
--nproc_per_node 7  \
finetuning.py \
--model_name meta-llama/Llama-3.2-1B \
--context_length 25000 \
--num_epochs 1 \
--dataset "custom_dataset" \
--custom_dataset.file "custom_dataset.py" \
--enable_fsdp \
--use_peft \
--peft_method lora \
--output_dir ./output \
--batch_size_training 1
Log
W1113 14:36:19.328000 23324 torch/distributed/run.py:793] 
W1113 14:36:19.328000 23324 torch/distributed/run.py:793] *****************************************
W1113 14:36:19.328000 23324 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1113 14:36:19.328000 23324 torch/distributed/run.py:793] *****************************************
/usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
--> Model meta-llama/Llama-3.2-1B

--> meta-llama/Llama-3.2-1B has 1235.8144 Million params

trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689
bFloat16 enabled for mixed precision - using bfSixteen policy
trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689
trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689
trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689
trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689
trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689
trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689
--> applying fsdp activation checkpointing...
--> Training Set Length = 35
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Validation Set Length = 9
Preprocessing dataset:   0%|                                                                              | 0/35 [00:00<?, ?it/s]
--> applying fsdp activation checkpointing...
Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 28.75it/s]
length of dataset_train 37
Preprocessing dataset:  60%|█████████████████████████████████████████████████████████████████████████████████████▊                                                         | 21/35 [00:00<00:00, 27.61it/s]Can not find the custom data_collator in the dataset.py file (custom_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 5
Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.53it/s]
--> Num of Validation Set Batches loaded = 1
--> Num of Validation Set Batches loaded = 1
Starting epoch 0/1
train_config.max_train_step: 0
Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 28.32it/s]
length of dataset_train 37
Can not find the custom data_collator in the dataset.py file (custom_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 5
Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 28.67it/s]
length of dataset_train 37
Can not find the custom data_collator in the dataset.py file (custom_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 5
Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 29.07it/s]
length of dataset_train 37
Can not find the custom data_collator in the dataset.py file (custom_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 5
Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 26.84it/s]
length of dataset_train 37
Can not find the custom data_collator in the dataset.py file (custom_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 5
Preprocessing dataset:   0%|                                                                              | 0/9 [00:00<?, ?it/s]
--> Num of Training Set Batches loaded = 5
Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 28.12it/s]
--> Num of Validation Set Batches loaded = 1
--> Num of Validation Set Batches loaded = 1
Starting epoch 0/1
train_config.max_train_step: 0
Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 28.01it/s]
--> Num of Validation Set Batches loaded = 1
--> Num of Validation Set Batches loaded = 1
Starting epoch 0/1
train_config.max_train_step: 0
Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.89it/s]
--> Num of Validation Set Batches loaded = 1
--> Num of Validation Set Batches loaded = 1
Starting epoch 0/1
train_config.max_train_step: 0
Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 26.90it/s]
--> Num of Validation Set Batches loaded = 1
--> Num of Validation Set Batches loaded = 1
Starting epoch 0/1
train_config.max_train_step: 0
Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 26.90it/s]
length of dataset_train 37
Can not find the custom data_collator in the dataset.py file (custom_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 5
Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 27.69it/s]
--> Num of Validation Set Batches loaded = 1
--> Num of Validation Set Batches loaded = 1
Starting epoch 0/1
train_config.max_train_step: 0
Preprocessing dataset:  33%|████████████████████████████████████████████████▎                                                                                                | 3/9 [00:00<00:00, 23.29it/s]/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                  | 0/5 [00:00<?, ?it/s]
--> Num of Validation Set Batches loaded = 1
--> Num of Validation Set Batches loaded = 1
Starting epoch 0/1
train_config.max_train_step: 0
/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                  | 0/5 [00:00<?, ?it/s]
[rank6]: Traceback (most recent call last):
[rank6]:     fire.Fire(main)
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 135, in Fire
[rank6]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank6]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 468, in _Fire
[rank6]:     component, remaining_args = _CallAndUpdateTrace(
[rank6]:                                 ^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank6]:     component = fn(*varargs, **kwargs)
[rank6]:                 ^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/llama_recipes/finetuning.py", line 311, in main
[rank6]:     results = train(
[rank6]:               ^^^^^^
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/llama_recipes/utils/train_utils.py", line 175, in train
[rank6]:     loss.backward()
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 581, in backward
[rank6]:     torch.autograd.backward(
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 347, in backward
[rank6]:     _engine_run_backward(
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank6]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 11.95 GiB. GPU 6 has a total capacity of 44.35 GiB of which 4.98 GiB is free. Process 1595066 has 39.36 GiB memory in use. Of the allocated memory 26.45 GiB is allocated by PyTorch, and 12.44 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Training Epoch: 1/1, step 0/5 completed (loss: 0.611149251461029):  20%|████████████████████▏                                                                                | 1/5 [00:11<00:44, 11.23s/it]
W1113 14:36:44.924000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23392 closing signal SIGTERM
W1113 14:36:44.925000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23393 closing signal SIGTERM
W1113 14:36:44.926000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23394 closing signal SIGTERM
W1113 14:36:44.927000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23395 closing signal SIGTERM
W1113 14:36:44.928000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23396 closing signal SIGTERM
W1113 14:36:44.929000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23397 closing signal SIGTERM
E1113 14:36:46.111000 23324 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 6 (pid: 23398) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in 
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
finetuning.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-11-13_14:36:44
  host      : a7552316fef1
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 23398)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

wukaixingxp commented 4 days ago

Hi! FSDP is mostly used when a model cannot fit on a single GPU, and it will not increase the training context length. There is a Context Parallel feature implemented in Torchtitan where the context itself is sharded and processed in parallel, so adding more GPUs does make an improvement there. Hopefully this can help with your problem.

lessw2020 commented 4 days ago

Hi @Amerehei, with FSDP you are not sharding the growing activations from the increased context length, so as you increase the context length your activations will continue to balloon, and the limit remains what a single GPU can handle. You get a tiny gain by sharding across more GPUs because the amount each GPU holds for the model weights/optimizer state/gradients shrinks a bit more, but that only frees slightly more room for the larger, unsharded sequence activations; as you saw, adding 6 more GPUs only raised your OOM limit by about 9%.
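
As a toy calculation (the 6 GiB of sharded state and ~1.8 GiB of activations per 1K tokens below are illustrative assumptions chosen to land in the observed ballpark, not measurements from this run):

# FSDP shards weights/gradients/optimizer state across ranks, but each rank
# still holds the full, unsharded activations for its own sample, and those
# grow linearly with context length.

def rough_max_context(gpu_gib, sharded_state_gib, act_gib_per_1k_tokens, world_size):
    free = gpu_gib - sharded_state_gib / world_size   # room left once the sharded state is split
    return int(free / act_gib_per_1k_tokens * 1000)

for world_size in (1, 7):
    print(world_size, rough_max_context(44, 6.0, 1.8, world_size))
# prints roughly 1 -> 21000 and 7 -> 24000: more ranks only shrink the small
# sharded slice, so the context ceiling barely moves.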

For longer context windows, you need to add another dimension of sharding, which in this case is what @wukaixingxp notes above: context parallel.
This would allow you to train with larger and larger context windows, because the sequence (context) activations themselves are sharded across GPUs. Titan is focused on pre-training, though, and here you are doing fine-tuning, so we'd need to look at adding context parallel to llama-recipes as the fix that lets you expand to larger contexts by adding more GPUs to the world size.
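
For intuition, here is the core idea as a minimal single-process sketch (a toy illustration of sharding the sequence dimension only; not the Torchtitan/PyTorch context-parallel implementation, which also has to exchange keys/values between ranks for attention):

import torch

# Context parallelism splits the *sequence* dimension across ranks, so the
# activations each rank materializes shrink with world_size, unlike FSDP,
# which only shards parameters/gradients/optimizer state.

def sequence_shard(hidden_states: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden]; each rank keeps one contiguous slice
    return hidden_states.chunk(world_size, dim=1)[rank]

full = torch.randn(1, 24_000, 2048)          # toy shapes
shard = sequence_shard(full, rank=0, world_size=8)
print(full.shape, shard.shape)               # seq dim goes from 24000 to 3000 per rank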

For reference, context parallel is a new feature that landed in Titan just a few weeks ago, but this chart shows its impact on supporting longer sequence lengths:

[Screenshot: chart showing context parallel's impact on supported sequence lengths in Titan]

Amerehei commented 4 days ago

@lessw2020 @wukaixingxp Thanks for the information