@Amerehei May I know what GPU you are using and how many GPUs you have? Fine-tuning needs a lot of memory, and a 50K context needs an insane amount of GPU memory; you can try using this GPU memory tool to calculate it. BTW, our Llama 3.1 model already supports a 128K context length, so you do not need to fine-tune to get that.
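As a rough, hand-rolled alternative to such a calculator, a minimal back-of-envelope sketch in Python might look like the following (my own assumptions, not from this thread: bf16 weights and activations, fp32 Adam states for the trainable parameters only, activation checkpointing keeping roughly one hidden state per layer, and logits upcast to fp32 for the loss; real peaks from attention, recomputation, and fragmentation will be higher):

# Very rough GPU-memory back-of-envelope for fine-tuning (all numbers approximate).
def estimate_finetune_gib(n_params, n_trainable, context_len, hidden, n_layers, vocab, batch=1):
    weights = n_params * 2                      # bf16 weights
    optim   = n_trainable * (4 + 4 + 4)         # fp32 master copy + Adam m/v (trainable params only)
    grads   = n_trainable * 2                   # bf16 gradients
    acts    = batch * context_len * (hidden * n_layers * 2   # one bf16 hidden state per layer
                                     + vocab * 4)            # fp32 logits for the loss
    return (weights + optim + grads + acts) / 1024**3

# Llama-3.2-1B-ish numbers (hidden=2048, 16 layers, 128256 vocab) with LoRA at 25K context:
print(f"{estimate_finetune_gib(1.24e9, 8.5e5, 25_000, 2048, 16, 128_256):.1f} GiB")  # ~15.8 GiB floor

Even this optimistic floor shows the per-sequence activation terms dominating everything else at long contexts.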
What is your command? Can you send me the error log?
This is the output of pip list, using Python 3.11.10 and PyTorch 2.5.1 (I tested 2.4 as well):
Package Version ---------------------------------------- -------------- absl-py 2.1.0 accelerate 1.1.1 aiohappyeyeballs 2.4.3 aiohttp 3.10.10 aiosignal 1.3.1 annotated-types 0.7.0 antlr4-python3-runtime 4.9.3 anyio 4.6.2.post1 appdirs 1.4.4 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 arrow 1.3.0 asttokens 2.4.1 async-lru 2.0.4 attrs 24.2.0 babel 2.16.0 backoff 2.2.1 beautifulsoup4 4.12.3 bitsandbytes 0.44.1 black 24.10.0 bleach 6.2.0 blinker 1.4 boltons 21.0.0 bracex 2.5.post1 Brotli 1.1.0 cachetools 5.5.0 certifi 2024.8.30 cffi 1.17.1 chardet 5.2.0 charset-normalizer 3.4.0 click 8.1.7 click-option-group 0.5.6 codeshield 1.0.1 colorama 0.4.6 coloredlogs 15.0.1 comm 0.2.2 contourpy 1.3.1 cryptography 43.0.3 cycler 0.12.1 dataclasses-json 0.6.7 datasets 3.1.0 dbus-python 1.2.18 debugpy 1.8.8 decorator 5.1.1 defusedxml 0.7.1 Deprecated 1.2.14 dill 0.3.8 distro 1.7.0 effdet 0.4.1 emoji 2.14.0 entrypoints 0.4 eval_type_backport 0.2.0 evaluate 0.4.3 exceptiongroup 1.2.2 executing 2.1.0 face 24.0.0 fastjsonschema 2.20.0 filelock 3.13.1 filetype 1.2.0 fire 0.7.0 flatbuffers 24.3.25 fonttools 4.54.1 fqdn 1.5.1 frozenlist 1.5.0 fsspec 2024.2.0 glom 22.1.0 google-api-core 2.23.0 google-auth 2.36.0 google-cloud-vision 3.8.0 googleapis-common-protos 1.66.0 grpcio 1.67.1 grpcio-status 1.62.3 h11 0.14.0 httpcore 1.0.6 httplib2 0.20.2 httpx 0.27.2 huggingface-hub 0.26.2 humanfriendly 10.0 idna 3.10 importlib_metadata 7.1.0 inflate64 1.0.0 iopath 0.1.10 ipykernel 6.29.5 ipython 8.29.0 ipython-genutils 0.2.0 ipywidgets 8.1.5 isoduration 20.11.0 jedi 0.19.2 jeepney 0.7.1 Jinja2 3.1.3 joblib 1.4.2 json5 0.9.28 jsonpath-python 1.0.6 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema-specifications 2024.10.1 jupyter-archive 3.4.0 jupyter_client 7.4.9 jupyter_contrib_core 0.4.2 jupyter_contrib_nbextensions 0.7.0 jupyter_core 5.7.2 jupyter-events 0.10.0 jupyter-highlight-selected-word 0.2.0 jupyter-lsp 2.2.5 jupyter_nbextensions_configurator 0.6.4 jupyter_server 2.14.2 jupyter_server_terminals 0.5.3 jupyterlab 4.2.5 jupyterlab_pygments 0.3.0 jupyterlab_server 2.27.3 jupyterlab_widgets 3.0.13 keyring 23.5.0 kiwisolver 1.4.7 langdetect 1.0.9 launchpadlib 1.10.16 layoutparser 0.3.4 lazr.restfulclient 0.14.4 lazr.uri 1.0.6 llama-recipes 0.0.4.post1 loralib 0.1.2 lxml 5.3.0 markdown-it-py 3.0.0 MarkupSafe 2.1.5 marshmallow 3.23.1 matplotlib 3.9.2 matplotlib-inline 0.1.7 mdurl 0.1.2 mistune 3.0.2 more-itertools 8.10.0 mpmath 1.3.0 multidict 6.1.0 multiprocess 0.70.16 multivolumefile 0.2.3 mypy-extensions 1.0.0 nbclassic 1.1.0 nbclient 0.10.0 nbconvert 7.16.4 nbformat 5.10.4 nest-asyncio 1.6.0 networkx 3.2.1 nltk 3.9.1 notebook 6.5.5 notebook_shim 0.2.4 numpy 1.26.3 nvidia-cublas-cu12 12.4.5.8 nvidia-cuda-cupti-cu12 12.4.127 nvidia-cuda-nvrtc-cu12 12.4.127 nvidia-cuda-runtime-cu12 12.4.127 nvidia-cudnn-cu12 9.1.0.70 nvidia-cufft-cu12 11.2.1.3 nvidia-curand-cu12 10.3.5.147 nvidia-cusolver-cu12 11.6.1.9 nvidia-cusparse-cu12 12.3.1.170 nvidia-nccl-cu12 2.21.5 nvidia-nvjitlink-cu12 12.4.127 nvidia-nvtx-cu12 12.4.127 oauthlib 3.2.0 omegaconf 2.3.0 onnx 1.17.0 onnxruntime 1.20.0 openai 1.38.0 opencv-python 4.10.0.84 opentelemetry-api 1.25.0 opentelemetry-exporter-otlp-proto-common 1.25.0 opentelemetry-exporter-otlp-proto-http 1.25.0 opentelemetry-instrumentation 0.46b0 opentelemetry-instrumentation-requests 0.46b0 opentelemetry-proto 1.25.0 opentelemetry-sdk 1.25.0 opentelemetry-semantic-conventions 0.46b0 opentelemetry-util-http 0.46b0 optimum 1.23.3 overrides 7.7.0 packaging 24.2 pandas 2.2.3 pandocfilters 1.5.1 parso 0.8.4 
pathspec 0.12.1 pdf2image 1.17.0 pdfminer.six 20231228 pdfplumber 0.11.4 peewee 3.17.8 peft 0.13.2 pexpect 4.9.0 pi_heif 0.20.0 pikepdf 9.4.1 pillow 10.2.0 pip 24.3.1 platformdirs 4.3.6 portalocker 2.10.1 prometheus_client 0.21.0 prompt_toolkit 3.0.48 propcache 0.2.0 proto-plus 1.25.0 protobuf 4.25.5 psutil 6.1.0 ptyprocess 0.7.0 pure_eval 0.2.3 py7zr 0.22.0 pyarrow 18.0.0 pyasn1 0.6.1 pyasn1_modules 0.4.1 pybcj 1.0.2 pycocotools 2.0.8 pycparser 2.22 pycryptodomex 3.21.0 pydantic 2.9.2 pydantic_core 2.23.4 Pygments 2.18.0 PyGObject 3.42.1 PyJWT 2.3.0 pyparsing 2.4.7 pypdf 5.1.0 pypdfium2 4.30.0 pyppmd 1.1.0 python-apt 2.4.0+ubuntu4 python-dateutil 2.8.2 python-iso639 2024.10.22 python-json-logger 2.0.7 python-magic 0.4.27 python-multipart 0.0.17 pytz 2024.2 PyYAML 6.0.1 pyzmq 24.0.1 pyzstd 0.16.2 RapidFuzz 3.10.1 referencing 0.35.1 regex 2024.11.6 requests 2.32.3 requests-toolbelt 1.0.0 rfc3339-validator 0.1.4 rfc3986-validator 0.1.1 rich 13.5.3 rouge_score 0.1.2 rpds-py 0.21.0 rsa 4.9 ruamel.yaml 0.17.40 ruamel.yaml.clib 0.2.12 safetensors 0.4.5 scikit-learn 1.5.2 scipy 1.14.1 SecretStorage 3.3.1 semgrep 1.96.0 Send2Trash 1.8.3 sentence-transformers 3.3.0 sentencepiece 0.2.0 setuptools 75.4.0 six 1.16.0 sniffio 1.3.1 soupsieve 2.6 stack-data 0.6.3 sympy 1.13.1 tabulate 0.9.0 termcolor 2.5.0 terminado 0.18.1 texttable 1.7.0 threadpoolctl 3.5.0 timm 1.0.11 tinycss2 1.4.0 tokenize_rt 6.1.0 tokenizers 0.20.3 tomli 2.0.2 torch 2.5.1+cu124 torchaudio 2.5.1+cu124 torchvision 0.20.1+cu124 tornado 6.4.1 tqdm 4.67.0 traitlets 5.14.3 transformers 4.46.2 triton 3.1.0 types-python-dateutil 2.9.0.20241003 typing_extensions 4.8.0 typing-inspect 0.9.0 tzdata 2024.2 unstructured 0.15.8 unstructured-client 0.27.0 unstructured-inference 0.7.36 unstructured.pytesseract 0.3.13 uri-template 1.3.0 urllib3 2.2.3 wadllib 1.3.6 wcmatch 8.5.2 wcwidth 0.2.13 webcolors 24.11.1 webencodings 0.5.1 websocket-client 1.8.0 wheel 0.45.0 widgetsnbextension 4.0.13 wrapt 1.16.0 xxhash 3.5.0 yarl 1.17.1 zipp 1.0.0
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Sep_12_02:18:05_PDT_2024
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07    CUDA Version: 12.6     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A40                     On  | 00000000:53:00.0 Off |                    0 |
|  0%   40C    P0             208W / 300W |  12700MiB / 46068MiB |     94%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A40                     On  | 00000000:57:00.0 Off |                    0 |
|  0%   38C    P0             194W / 300W |  12700MiB / 46068MiB |     94%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A40                     On  | 00000000:98:00.0 Off |                    0 |
|  0%   41C    P0             196W / 300W |  12676MiB / 46068MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A40                     On  | 00000000:CE:00.0 Off |                    0 |
|  0%   38C    P0             193W / 300W |  12700MiB / 46068MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A40                     On  | 00000000:D1:00.0 Off |                    0 |
|  0%   41C    P0             205W / 300W |  12724MiB / 46068MiB |     99%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A40                     On  | 00000000:D2:00.0 Off |                    0 |
|  0%   39C    P0             212W / 300W |  12724MiB / 46068MiB |     99%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A40                     On  | 00000000:D6:00.0 Off |                    0 |
|  0%   41C    P0             213W / 300W |  12654MiB / 46068MiB |     99%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
I run the following command; here I use a very small dataset just for testing (I also tried the default dataset).

Using a single GPU, I can increase context_length up to 22000, and it fails at 23000.
torchrun --nnodes 1 \
--nproc_per_node 1 \
finetuning.py \
--model_name meta-llama/Llama-3.2-1B \
--context_length 22000 \
--num_epochs 1 \
--dataset "custom_dataset" \
--custom_dataset.file "custom_dataset.py" \
--enable_fsdp \
--use_peft \
--peft_method lora \
--output_dir ./output \
--batch_size_training 1
When I increase nproc_per_node to 7, I can increase context_length to 24000, and it fails at 25000. Is this normal? Why can a single GPU handle 22000 tokens while adding 6 more GPUs only increases the context size by about 9%?
torchrun --nnodes 1 \
--nproc_per_node 7 \
finetuning.py \
--model_name meta-llama/Llama-3.2-1B \
--context_length 25000 \
--num_epochs 1 \
--dataset "custom_dataset" \
--custom_dataset.file "custom_dataset.py" \
--enable_fsdp \
--use_peft \
--peft_method lora \
--output_dir ./output \
--batch_size_training 1
W1113 14:36:19.328000 23324 torch/distributed/run.py:793] W1113 14:36:19.328000 23324 torch/distributed/run.py:793] ***************************************** W1113 14:36:19.328000 23324 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W1113 14:36:19.328000 23324 torch/distributed/run.py:793] ***************************************** /usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead from torch.distributed._shard.checkpoint import ( /usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead from torch.distributed._shard.checkpoint import ( /usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead from torch.distributed._shard.checkpoint import ( /usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead from torch.distributed._shard.checkpoint import ( /usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead from torch.distributed._shard.checkpoint import ( /usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead from torch.distributed._shard.checkpoint import ( /usr/local/lib/python3.11/dist-packages/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead from torch.distributed._shard.checkpoint import ( Clearing GPU cache for all ranks --> Running with torch dist debug set to detail --> Model meta-llama/Llama-3.2-1B --> meta-llama/Llama-3.2-1B has 1235.8144 Million params trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689 bFloat16 enabled for mixed precision - using bfSixteen policy trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689 trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689 trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689 trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689 trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689 trainable params: 851,968 || all params: 1,236,666,368 || trainable%: 0.0689 --> applying fsdp activation checkpointing... --> Training Set Length = 35 --> applying fsdp activation checkpointing... --> applying fsdp activation checkpointing... --> applying fsdp activation checkpointing... --> applying fsdp activation checkpointing... --> applying fsdp activation checkpointing... 
--> Validation Set Length = 9 Preprocessing dataset: 0%| | 0/35 [00:00, ?it/s]--> applying fsdp activation checkpointing... Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 28.75it/s] length of dataset_train 37 Preprocessing dataset: 60%|█████████████████████████████████████████████████████████████████████████████████████▊ | 21/35 [00:00<00:00, 27.61it/s]Can not find the custom data_collator in the dataset.py file (custom_dataset.py). Using the default data_collator instead. --> Num of Training Set Batches loaded = 5 Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.53it/s] --> Num of Validation Set Batches loaded = 1 --> Num of Validation Set Batches loaded = 1 Starting epoch 0/1 train_config.max_train_step: 0 Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 28.32it/s] length of dataset_train 37 Can not find the custom data_collator in the dataset.py file (custom_dataset.py). Using the default data_collator instead. --> Num of Training Set Batches loaded = 5 Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 28.67it/s] length of dataset_train 37 Can not find the custom data_collator in the dataset.py file (custom_dataset.py). Using the default data_collator instead. --> Num of Training Set Batches loaded = 5 Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 29.07it/s] length of dataset_train 37 Can not find the custom data_collator in the dataset.py file (custom_dataset.py). Using the default data_collator instead. --> Num of Training Set Batches loaded = 5 Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 26.84it/s] length of dataset_train 37 Can not find the custom data_collator in the dataset.py file (custom_dataset.py). Using the default data_collator instead. --> Num of Training Set Batches loaded = 5 Preprocessing dataset: 0%| | 0/9 [00:00, ?it/s]/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats. warnings.warn( Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 26.84it/s] length of dataset_train 37 Can not find the custom data_collator in the dataset.py file (custom_dataset.py). Using the default data_collator instead. 
--> Num of Training Set Batches loaded = 5 Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 28.12it/s] --> Num of Validation Set Batches loaded = 1 --> Num of Validation Set Batches loaded = 1 Starting epoch 0/1 train_config.max_train_step: 0 Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 28.01it/s] --> Num of Validation Set Batches loaded = 1 --> Num of Validation Set Batches loaded = 1 Starting epoch 0/1 train_config.max_train_step: 0 Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 29.89it/s] --> Num of Validation Set Batches loaded = 1 --> Num of Validation Set Batches loaded = 1 Starting epoch 0/1 train_config.max_train_step: 0 Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 26.90it/s] --> Num of Validation Set Batches loaded = 1 --> Num of Validation Set Batches loaded = 1 Starting epoch 0/1 train_config.max_train_step: 0 Preprocessing dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:01<00:00, 26.90it/s] length of dataset_train 37 Can not find the custom data_collator in the dataset.py file (custom_dataset.py). Using the default data_collator instead. --> Num of Training Set Batches loaded = 5 Preprocessing dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 27.69it/s] --> Num of Validation Set Batches loaded = 1 --> Num of Validation Set Batches loaded = 1 Starting epoch 0/1 train_config.max_train_step: 0 Preprocessing dataset: 33%|████████████████████████████████████████████████▎ | 3/9 [00:00<00:00, 23.29it/s]/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats. warnings.warn( Training Epoch: 1: 0%| | 0/5 [00:00, ?it/s]/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats. warnings.warn( Training Epoch: 1: 0%| | 0/5 [00:00, ?it/s]/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats. warnings.warn( Preprocessing dataset: 78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 7/9 [00:00<00:00, 30.95it/s]/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats. 
warnings.warn(
Preprocessing dataset: 100%|██████████| 9/9 [00:00<00:00, 26.38it/s]
--> Num of Validation Set Batches loaded = 1
--> Num of Validation Set Batches loaded = 1
Starting epoch 0/1
train_config.max_train_step: 0
/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|          | 0/5 [00:00, ?it/s]/usr/local/lib/python3.11/dist-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1/1, step 0/5 completed (loss: 0.4009278118610382):  20%|██        | 1/5 [00:08<00:35, 8.78s/it]
[rank6]: Traceback (most recent call last):
[rank6]:   File "/workspace/finetuning.py", line 5, in <module>
[rank6]:     fire.Fire(main)
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 135, in Fire
[rank6]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank6]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 468, in _Fire
[rank6]:     component, remaining_args = _CallAndUpdateTrace(
[rank6]:                                 ^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank6]:     component = fn(*varargs, **kwargs)
[rank6]:                 ^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/llama_recipes/finetuning.py", line 311, in main
[rank6]:     results = train(
[rank6]:               ^^^^^^
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/llama_recipes/utils/train_utils.py", line 175, in train
[rank6]:     loss.backward()
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 581, in backward
[rank6]:     torch.autograd.backward(
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 347, in backward
[rank6]:     _engine_run_backward(
[rank6]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank6]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 11.95 GiB. GPU 6 has a total capacity of 44.35 GiB of which 4.98 GiB is free. Process 1595066 has 39.36 GiB memory in use. Of the allocated memory 26.45 GiB is allocated by PyTorch, and 12.44 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.
See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Training Epoch: 1/1, step 0/5 completed (loss: 0.611149251461029):  20%|██        | 1/5 [00:11<00:44, 11.23s/it]
W1113 14:36:44.924000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23392 closing signal SIGTERM
W1113 14:36:44.925000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23393 closing signal SIGTERM
W1113 14:36:44.926000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23394 closing signal SIGTERM
W1113 14:36:44.927000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23395 closing signal SIGTERM
W1113 14:36:44.928000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23396 closing signal SIGTERM
W1113 14:36:44.929000 23324 torch/distributed/elastic/multiprocessing/api.py:897] Sending process 23397 closing signal SIGTERM
E1113 14:36:46.111000 23324 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 6 (pid: 23398) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
finetuning.py FAILED
------------------------------------------------------------
Failures:
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-11-13_14:36:44
  host      : a7552316fef1
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 23398)
  error_file:
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
Hi! FSDP is mostly used when a model cannot fit on a single GPU; by itself it will not increase the training context length. There is a Context Parallel feature implemented in Torchtitan where the context itself is processed/parallelized across GPUs, so adding more GPUs does help there. Hopefully this can help with your problem.
Hi @Amerehei, with FSDP you are not sharding the activations that grow with the context length, so as you increase the context length your activations continue to balloon and the limit remains roughly what one GPU can handle. You get a tiny gain by sharding across more GPUs because the amount each GPU holds for the model weights, optimizer states, and gradients shrinks a bit more, but that only creates slightly more room to hold the larger, unsharded sequence activations; as you saw, adding 6 more GPUs only raised your OOM limit by about 9%.
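To put a rough number on that ballooning (my own back-of-envelope, assuming Llama 3's 128,256-token vocabulary and an fp32 upcast of the logits for the loss; the traceback does not actually say which tensor failed to allocate):

seq_len, vocab_size = 25_000, 128_256             # vocab size assumed from Llama 3
logits_gib = seq_len * vocab_size * 4 / 1024**3   # one fp32 logits tensor for one sequence
print(f"{logits_gib:.2f} GiB")                    # ~11.94 GiB, close to the 11.95 GiB allocation in the OOM above

FSDP shards none of this per-sequence memory, which is why the extra GPUs buy so little headroom.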
For longer context windows, you need to add another dimension of sharding, which in this case is what @wukaixingxp notes above: context parallel.
This sharding then allows you to train with larger and larger context windows, since the sequence (context) activations are sharded across GPUs.
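Conceptually it looks like the toy sketch below (a single-process illustration of the idea only, not Torchtitan's actual implementation; real context parallelism also has to exchange attention keys/values or queries between ranks):

import torch

# Toy illustration of sequence (context) sharding: each rank keeps only a
# seq_len / world_size slice of the per-layer activations, so activation
# memory per GPU shrinks roughly linearly with the number of ranks.
world_size, seq_len, hidden = 7, 25_000, 2048
full_acts = torch.empty(1, seq_len, hidden, dtype=torch.bfloat16)   # unsharded activations
shards = torch.tensor_split(full_acts, world_size, dim=1)           # one slice per rank

mib = lambda t: t.numel() * t.element_size() / 1024**2
print(f"unsharded: {mib(full_acts):.0f} MiB per layer, per-rank shard: {mib(shards[0]):.0f} MiB")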
Titan is focused on pre-training, though, and here you are doing fine-tuning, so we'd need to look at adding context parallel to llama-recipes as the fix that lets you expand to larger contexts by adding more GPUs to the world size.
For reference, context parallel is a new feature that landed a few weeks ago in Titan, but this chart shows its impact on supporting longer sequence lengths:
@lessw2020 @wukaixingxp Thanks for the information
When I increase --context_length, no matter how many GPUs are available, it ends with an OutOfMemoryError. I'm wondering why fully sharded FSDP doesn't work as I expected.

How can I train a model with a context length of more than 50K?