PanagiotisFytas opened 3 weeks ago
While running:
```
torchrun --nnodes 1 --nproc_per_node 2 recipes/quickstart/finetuning/finetuning.py \
    --use_peft \
    --peft_method lora \
    --model_name 'meta-llama/Llama-3.1-70B-Instruct' \
    --output_dir './my_lora_weights/70B' \
    --batch_size_training 1 \
    --batching_strategy "padding" \
    --weight_decay 0.2 \
    --num_epochs 10 \
    --dataset custom_dataset \
    --quantization '4bit' \
    --enable_fsdp True \
    --use_fast_kernels True
```
The code that leads to the error is from llama-recipes (https://github.com/meta-llama/llama-recipes/blob/98707b72fda091b2b20e3ab2ffaf9a86e4fccd84/src/llama_recipes/model_checkpointing/checkpoint_handler.py#L273):
```python
def save_peft_checkpoint(model, model_path):
    """save_pretrained peft model"""
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)

    if isinstance(model, FSDP):
        state_dict = get_model_state_dict(model, options=options)
        model.save_pretrained(model_path, state_dict=state_dict)
    else:
        model.save_pretrained(model_path)
```
... rank1: File "/home/Documents/llama-recipes/src/llama_recipes/utils/train_utils.py", line 259, in train rank1: save_peft_checkpoint(model, train_config.output_dir) rank1: File "/home/Documents/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py", line 276, in save_peft_checkpoint rank1: state_dict = get_model_state_dict(model, options=options) rank1: File "/home/miniconda3/envs/llama_recipes_new/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict.py", line 995, in get_model_state_dict rank1: model_state_dict = _get_model_state_dict(model, info) rank1: File "/home/miniconda3/envs/llama_recipes_new/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context rank1: return func(*args, **kwargs) rank1: File "/home/miniconda3/envs/llama_recipes_new/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict.py", line 475, in _get_model_state_dict rank1: fqns = _get_fqns(model, key) rank1: File "/home/miniconda3/envs/llama_recipes_new/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict.py", line 224, in _get_fqns rank1: curr_obj = getattr(curr_obj, curr_obj_name)
Apparently, per https://github.com/meta-llama/llama-recipes/issues/674, a temporary fix is to set `cpu_offload=False`, but this is only a band-aid that disables CPU offloading entirely.
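Concretely, the band-aid amounts to this one-line change in `save_peft_checkpoint`:

```python
# Workaround from issue #674: gather the full state dict on GPU instead of
# offloading to CPU. This avoids the _get_fqns crash, but for a 70B model the
# gathered tensors can exceed GPU memory (see the comment below).
options = StateDictOptions(full_state_dict=True, cpu_offload=False)
```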
I have the same issue with Llama 70B FSDP QLoRA training. When CPU offloading is turned off, it runs out of memory (as one might expect).
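A possible middle ground (an assumption on my part, not verified against this crash): since only the LoRA adapter weights need to go into the PEFT checkpoint, `StateDictOptions` has an `ignore_frozen_params` flag that skips every parameter with `requires_grad=False`. That would shrink the gathered state dict from the full 70B weights down to just the adapters, which may make `cpu_offload=False` survivable:

```python
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)

# Untested idea: skip the frozen 4-bit base weights and gather only the
# trainable LoRA adapters. ignore_frozen_params is a real StateDictOptions
# field; whether it also sidesteps the _get_fqns walk over the quantized
# modules is an open question.
options = StateDictOptions(
    full_state_dict=True,
    cpu_offload=False,          # keep the band-aid from above
    ignore_frozen_params=True,  # drop params with requires_grad=False
)
state_dict = get_model_state_dict(model, options=options)
model.save_pretrained(model_path, state_dict=state_dict)
```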