meta-llama / llama-recipes

Scripts for fine-tuning Meta Llama with composable FSDP & PEFT methods, covering single- and multi-node GPU setups. Supports default and custom datasets for applications such as summarization and Q&A, plus a number of inference solutions such as HF TGI and vLLM for local or cloud deployment. Includes demo apps showcasing Meta Llama for WhatsApp & Messenger.

add freeze_LLM_only option for mllama finetuning #791

Closed: JimChienTW closed this 2 days ago

JimChienTW commented 6 days ago

What does this PR do?

Fixes #770

Feature/Issue Validation/Testing

To follow the training settings in the original paper, as discussed in issue #770, I added a new option that trains the vision encoder, projector, and cross-attention layers inside the LLM while keeping the rest of the language model frozen. Setting train_config.freeze_LLM_only to True enables this behavior.
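Conceptually, the freezing logic looks roughly like the sketch below (illustrative only; the module and function names here are assumptions, see the actual diff for the real code):

```python
import torch.nn as nn

def freeze_llm_only(model: nn.Module) -> None:
    """Illustrative sketch: freeze the language model except its cross-attention layers.

    Assumes an mllama-style model with `language_model`, `vision_model`, and
    `multi_modal_projector` submodules (as in MllamaForConditionalGeneration).
    """
    # Freeze every language-model parameter first.
    for param in model.language_model.parameters():
        param.requires_grad = False
    # Re-enable gradients for the cross-attention decoder layers inside the LLM.
    for layer in model.language_model.model.layers:
        if "CrossAttention" in type(layer).__name__:
            for param in layer.parameters():
                param.requires_grad = True
    # The vision encoder and multi-modal projector are left untouched, i.e. trainable.
```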

I conducted two tests:

  1. Using test_finetuning.py.
  2. Running the finetuning script finetuning.py directly.

Both tests passed successfully. In detail, I ran the finetuning process on 8×H100 GPUs. The process was smooth, as shown below.

src/tests/test_finetuning.py ...................... [100%]

================================================================ warnings summary =================================================================
../../llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17
  /media/Pluto/jim/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: torch.distributed._shard.checkpoint will be deprecated, use torch.distributed.checkpoint instead
    from torch.distributed._shard.checkpoint import (

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================== 22 passed, 1 warning in 3.76s ==========================================================


- [x] torchrun --nnodes 1 --nproc_per_node 8  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding --freeze_LLM_only True

W1116 16:26:22.743000 23456244184896 torch/distributed/run.py:757]
W1116 16:26:22.743000 23456244184896 torch/distributed/run.py:757]
W1116 16:26:22.743000 23456244184896 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1116 16:26:22.743000 23456244184896 torch/distributed/run.py:757]
in oss file
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 10.94it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 12.16it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 9.53it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 7.23it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 11.39it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 7.16it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 7.32it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 8.04it/s]
bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 1333.777617 Million params

README.md: 100%|██████████████████████████████████████████████████| 50.3k/50.3k [00:00<00:00, 1.11MB/s]
--> applying fsdp activation checkpointing...
(…)-00000-of-00011-f83c2bdf2cf711bf.parquet: 100%|██████████████████| 540M/540M [00:12<00:00, 42.0MB/s]
(…)-00001-of-00011-fef40eeeea84a563.parquet: 100%|██████████████████| 580M/580M [00:13<00:00, 42.1MB/s]
(…)-00002-of-00011-c0733bedbcc41420.parquet: 100%|██████████████████| 541M/541M [00:12<00:00, 42.3MB/s]
(…)-00003-of-00011-fee117dc7680fb5f.parquet: 100%|██████████████████| 577M/577M [00:13<00:00, 41.2MB/s]
(…)-00004-of-00011-c01c965b3ac5c2c0.parquet: 100%|██████████████████| 581M/581M [00:13<00:00, 42.6MB/s]
(…)-00005-of-00011-7eb79ee48c0c4065.parquet: 100%|██████████████████| 527M/527M [00:12<00:00, 42.6MB/s]
(…)-00006-of-00011-4a139e7c78fb5e47.parquet: 100%|██████████████████| 519M/519M [00:12<00:00, 41.5MB/s]
(…)-00007-of-00011-8f649db4d5664766.parquet: 100%|██████████████████| 559M/559M [00:24<00:00, 22.5MB/s]
(…)-00008-of-00011-23185b703995741f.parquet: 100%|██████████████████| 555M/555M [00:13<00:00, 42.6MB/s]
(…)-00009-of-00011-b0bb42debccbf310.parquet: 100%|██████████████████| 519M/519M [00:22<00:00, 22.7MB/s]
(…)-00010-of-00011-74ed380c1a2c83aa.parquet: 100%|██████████████████| 579M/579M [00:14<00:00, 41.0MB/s]
Generating train split: 100%|████████████████████████| 165746/165746 [00:06<00:00, 27228.77 examples/s]
--> Training Set Length = 1800
--> Validation Set Length = 200
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/112 [00:00<?, ?it/s]
use_cache=True is incompatible with gradient checkpointing. Setting use_cache=False.
Training Epoch: 1/3, step 19/112 completed (loss: 0.032936301082372665):  18%|▏| 20/112 [00:48<02:44,
Training Epoch: 1/3, step 20/112 completed (loss: 0.03712736815214157):  19%|▏| 21/112 [00:50<02:42, 1
Training Epoch: 1/3, step 22/112 completed (loss: 0.11487767100334167):  21%|▏| 23/112 [00:53<02:38,



## Before submitting
- [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/facebookresearch/llama-recipes/blob/main/CONTRIBUTING.md#pull-requests),
      Pull Request section?
- [x] Was this discussed/approved via a Github issue? Please add a link
      to it if that's the case.
- [x] Did you make sure to update the documentation with your changes?  
- [ ] Did you write any new necessary tests?

Thanks for contributing 🎉!
JimChienTW commented 6 days ago

This is my first time contributing to open source, and I’d really appreciate any feedback or advice you can share!

init27 commented 5 days ago

@JimChienTW Really appreciate you contributing to our repository, and congrats on your first contribution! We will review your PR this week.

Thanks again!

wukaixingxp commented 3 days ago

@JimChienTW Thanks for your PR, but I wonder about the parameter counts: with freeze_LLM_only my run reports 709.622115 Million trainable params for the 11B model, and without it 2667.555217 Million params, while Llama-3.2-11B-Vision-Instruct should have 10670.220835 Million params. Can you double-check whether something is wrong? Please see the logs below:

with freeze_LLM log:

~/work/to_merge/llama-recipes (add_vision_finetuning_features)]$ torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding --freeze_LLM_only True
W1118 13:42:20.313000 140448872657920 torch/distributed/run.py:779] 
W1118 13:42:20.313000 140448872657920 torch/distributed/run.py:779] *****************************************
W1118 13:42:20.313000 140448872657920 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1118 13:42:20.313000 140448872657920 torch/distributed/run.py:779] *****************************************
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.64it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.59it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.63it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.62it/s]
bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 709.622115 Million params

--> Training Set Length = 1800
--> Validation Set Length = 200
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
NCCL version 2.20.5+cuda12.4
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Training Epoch: 1/3, step 49/225 completed (loss: 0.1835767775774002): ^C
W1118 13:44:23.469000 140448872657920 torch/distributed/elastic/agent/server/api.py:688] Received Signals.SIGINT death signal, shutting down workers

and without freeze_LLM:


~/work/to_merge/llama-recipes (add_vision_finetuning_features)]$ torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding 
W1118 13:44:32.212000 140053915866112 torch/distributed/run.py:779] 
W1118 13:44:32.212000 140053915866112 torch/distributed/run.py:779] *****************************************
W1118 13:44:32.212000 140053915866112 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1118 13:44:32.212000 140053915866112 torch/distributed/run.py:779] *****************************************
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.27it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  7.22it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.84it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.51it/s]
bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 2667.555217 Million params

--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Training Set Length = 1800
--> Validation Set Length = 200
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
NCCL version 2.20.5+cuda12.4
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Training Epoch: 1/3, step 9/225 completed (loss: 0.24111326038837433)
wukaixingxp commented 3 days ago

run with latest main:

torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding 
W1118 13:56:29.641000 140413946110976 torch/distributed/run.py:779] 
W1118 13:56:29.641000 140413946110976 torch/distributed/run.py:779] *****************************************
W1118 13:56:29.641000 140413946110976 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1118 13:56:29.641000 140413946110976 torch/distributed/run.py:779] *****************************************
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.72it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.82it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  7.02it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.56it/s]
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 10670.220835 Million params

bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
JimChienTW commented 3 days ago

Thank you for your review. I found that the discrepancy was caused by printing the model parameter count after FSDP wrapping, so each rank reported only its local shard rather than the full model. Problem solved.
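As a quick sanity check on the numbers (assuming the earlier counts were per-rank shard sizes from the 4-GPU runs above):

```python
# Numbers taken from the logs in this thread; the earlier runs used --nproc_per_node 4.
world_size = 4
per_rank_params_m = 2667.555217            # Million params reported per rank after FSDP wrapping
print(per_rank_params_m * world_size)      # ~10670.22 Million, the full Llama-3.2-11B-Vision-Instruct count
```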

with freeze_LLM log:

torchrun --nnodes 1 --nproc_per_node 2  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding --freeze_LLM_only True
W1119 16:12:05.142000 844228 site-packages/torch/distributed/run.py:793] 
W1119 16:12:05.142000 844228 site-packages/torch/distributed/run.py:793] *****************************************
W1119 16:12:05.142000 844228 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1119 16:12:05.142000 844228 site-packages/torch/distributed/run.py:793] *****************************************
/media/Pluto/jim/opensource_contribute/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/media/Pluto/jim/opensource_contribute/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 12.54it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 12.68it/s]
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 10670.220835 Million params

After freezing the model:
--> meta-llama/Llama-3.2-11B-Vision-Instruct has 2639.926819 Million trainable params

--> Model state after freezing:
    vision_model: Unfrozen
    language_model: Mixed
    multi_modal_projector: Unfrozen

bFloat16 enabled for mixed precision - using bfSixteen policy
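The "Model state after freezing" summary above could be produced by a small helper along these lines (a sketch with assumed names, not necessarily the exact code in this PR):

```python
import torch.nn as nn

def summarize_freeze_state(model: nn.Module) -> None:
    # For each top-level submodule, report whether all, none, or some parameters require grad.
    for name, module in model.named_children():
        flags = [p.requires_grad for p in module.parameters()]
        if not flags:
            state = "No parameters"
        elif all(flags):
            state = "Unfrozen"
        elif not any(flags):
            state = "Frozen"
        else:
            state = "Mixed"
        print(f"    {name}: {state}")
```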

and without freeze_LLM:

W1119 16:17:19.279000 844690 site-packages/torch/distributed/run.py:793] 
W1119 16:17:19.279000 844690 site-packages/torch/distributed/run.py:793] *****************************************
W1119 16:17:19.279000 844690 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1119 16:17:19.279000 844690 site-packages/torch/distributed/run.py:793] *****************************************
/media/Pluto/jim/opensource_contribute/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/media/Pluto/jim/opensource_contribute/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 13.18it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 13.68it/s]
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 10670.220835 Million params

bFloat16 enabled for mixed precision - using bfSixteen policy