meta-llama / llama-recipes

Scripts for fine-tuning Meta Llama with composable FSDP & PEFT methods to cover single/multi-node GPUs. Supports default & custom datasets for applications such as summarization and Q&A. Supports a number of candidate inference solutions such as HF TGI and vLLM for local or cloud deployment. Includes demo apps that showcase Meta Llama for WhatsApp & Messenger.
15.45k stars, 2.23k forks

llama finetune.py throws pytorch tensor datatype error with 4 bit quantization #675

Open AAndersn opened 2 months ago

AAndersn commented 2 months ago

System Info

PyTorch 2.4.0, CUDA 12.1, CentOS HPC cluster with 7x H100 GPUs
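
Because this thread turns on exact version combinations (transformers 4.45.0 vs 4.46.3, torch 2.4.0 vs 2.4.1/2.5.1), a small stdlib-only script can collect them for a report. This is a sketch, not something llama-recipes ships; the package list (PACKAGES) and the helper name collect_versions are our own assumptions.

```python
"""Collect the package versions that matter for this kind of bug report."""
import platform
from importlib.metadata import PackageNotFoundError, version

# Hypothetical list of distributions to report; edit as needed
PACKAGES = ("torch", "transformers", "bitsandbytes", "peft", "accelerate", "llama-recipes")


def collect_versions(packages=PACKAGES):
    """Return {name: version string, or None if the package is not installed}."""
    found = {"python": platform.python_version()}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name:>14}: {ver or 'not installed'}")
```

The CUDA build that PyTorch was compiled against is not a pip package, so it is not covered here; torch.version.cuda reports it.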

Information

🐛 Describe the bug

FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 python -m torch.distributed.launch \
    --nnodes 1 \
    --nproc_per_node 5 \
    -m llama_recipes.finetuning \
    --enable_fsdp \
    --model_name meta-llama/Meta-Llama-3.1-70B \
    --quantization 4bit \
    --use_peft \
    --peft_method lora \
    --dataset grammar_dataset \
    --lr 5e-5 \
    --save_model \
    --use_wandb \
    --output_dir /qfs/people/usr/models/70B

Error logs

Loading checkpoint shards:   0%|          | 0/4 [00:02<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank1]:     return _run_code(code, main_globals, None,
[rank1]:   File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
[rank1]:     exec(code, run_globals)
[rank1]:   File "/qfs/people/usr/llama-recipes/src/llama_recipes/finetuning.py", line 291, in <module>
[rank1]:     fire.Fire(main)
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank1]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank1]:     component, remaining_args = _CallAndUpdateTrace(
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank1]:     component = fn(*varargs, **kwargs)
[rank1]:   File "/qfs/people/usr/llama-recipes/src/llama_recipes/finetuning.py", line 121, in main
[rank1]:     model = LlamaForCausalLM.from_pretrained(
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3960, in from_pretrained
[rank1]:     ) = cls._load_pretrained_model(
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4434, in _load_pretrained_model
[rank1]:     new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/transformers/modeling_utils.py", line 970, in _load_state_dict_into_meta_model
[rank1]:     value = type(value)(value.data.to("cpu"), **value.__dict__)
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 149, in __new__
[rank1]:     self = torch.Tensor._make_subclass(cls, data, requires_grad)
[rank1]: RuntimeError: Only Tensors of floating point and complex dtype can require gradients
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
W0922 19:40:05.383000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 60528 closing signal SIGTERM
W0922 19:40:05.383000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 60529 closing signal SIGTERM
W0922 19:40:05.383000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 60530 closing signal SIGTERM
W0922 19:40:05.383000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 60532 closing signal SIGTERM
E0922 19:40:05.857000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 3 (pid: 60531) of binary: /qfs/people/usr/venv_llama_2/bin/python

This error message is repeated by each GPU process, followed by:

Traceback (most recent call last):
  File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 208, in <module>
    main()
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/typing_extensions.py", line 2360, in wrapper
    return arg(*args, **kwargs)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 204, in main
    launch(args)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 189, in launch
    run(args)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
llama_recipes.finetuning FAILED
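
The RuntimeError above is PyTorch refusing to mark a non-floating-point tensor as requiring gradients; the failing call is torch.Tensor._make_subclass inside bitsandbytes (presumably its Params4bit class). Our reading, not confirmed against the bitsandbytes source for these exact versions, is that the packed 4-bit weight arrived in an integer (uint8) container with requires_grad=True. The sketch below reproduces only the PyTorch rule; PackedWeight, make_param and try_make are hypothetical names.

```python
import torch


class PackedWeight(torch.Tensor):
    """Hypothetical stand-in for a tensor subclass such as a 4-bit parameter."""


def make_param(data: torch.Tensor, requires_grad: bool):
    # Same low-level call that appears in the traceback (bitsandbytes modules.py, line 149)
    return torch.Tensor._make_subclass(PackedWeight, data, requires_grad)


def try_make(dtype: torch.dtype, requires_grad: bool = True):
    """Return None on success, or the RuntimeError message on failure."""
    try:
        make_param(torch.zeros(8, dtype=dtype), requires_grad)
        return None
    except RuntimeError as err:
        return str(err)


if __name__ == "__main__":
    # uint8 is the usual container for packed 4-bit weights; bfloat16 is the
    # container selected by bnb_4bit_quant_storage=torch.bfloat16
    for dtype in (torch.uint8, torch.bfloat16):
        print(dtype, "->", try_make(dtype) or "ok")
```

With a uint8 container the call only succeeds when requires_grad is False; with a floating-point container it succeeds either way.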

If the command is run without the FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 environment-variable prefix, it throws a different error:

ValueError: Cannot flatten integer dtype tensors
[rank0]: Traceback (most recent call last):
[rank0]:   File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/qfs/people/usr/llama-recipes/src/llama_recipes/finetuning.py", line 291, in <module>
[rank0]:     fire.Fire(main)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank0]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank0]:     component, remaining_args = _CallAndUpdateTrace(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank0]:     component = fn(*varargs, **kwargs)
[rank0]:   File "/qfs/people/usr/llama-recipes/src/llama_recipes/finetuning.py", line 179, in main
[rank0]:     model = FSDP(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
[rank0]:     _auto_wrap(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
[rank0]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   [Previous line repeated 2 more times]
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
[rank0]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
[rank0]:     return wrapper_cls(module, **kwargs)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank0]:     _init_param_handle_from_module(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
[rank0]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
[rank0]:     handle = FlatParamHandle(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
[rank0]:     self._init_flat_param_and_metadata(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
[rank0]:     ) = self._validate_tensors_to_flatten(params)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 769, in _validate_tensors_to_flatten
[rank0]:     raise ValueError("Cannot flatten integer dtype tensors")
[rank0]: ValueError: Cannot flatten integer dtype tensors

E0923 09:17:49.746000 47893711004800 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 44819) of binary: /qfs/people/usr/venv_llama_2/bin/python
Traceback (most recent call last):
  File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 208, in <module>
    main()
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/typing_extensions.py", line 2360, in wrapper
    return arg(*args, **kwargs)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 204, in main
    launch(args)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 189, in launch
    run(args)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
llama_recipes.finetuning FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-09-23_09:17:48
  host      : h100-02.local
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 44819)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
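
The ValueError comes from FSDP's _validate_tensors_to_flatten: each wrapped unit's parameters are flattened into a single FlatParameter, and integer dtypes are rejected. A quick check run on the model right before the FSDP(...) call shows whether any quantized weights are still in an integer container. This is a diagnostic sketch; non_float_params is a hypothetical helper, not part of llama-recipes.

```python
import torch
from torch import nn


def non_float_params(model: nn.Module):
    """List (name, dtype) for every parameter that is not floating point.

    FSDP rejects integer dtypes when flattening ("Cannot flatten integer dtype
    tensors"), so every entry returned here is a candidate for that error.
    """
    return [
        (name, param.dtype)
        for name, param in model.named_parameters()
        if not param.is_floating_point()
    ]


if __name__ == "__main__":
    # Toy model: one ordinary layer plus a frozen uint8 "packed" weight,
    # standing in for a 4-bit layer whose storage dtype is uint8
    toy = nn.Module()
    toy.proj = nn.Linear(4, 4)
    toy.packed = nn.Parameter(torch.zeros(8, dtype=torch.uint8), requires_grad=False)
    print(non_float_params(toy))
```

An empty list suggests the packed weights use a floating-point storage dtype, which is what FSDP needs here.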

Expected behavior

This command and dataset work fine for Llama 3.1 8B without quantization, but fail with 4-bit quantization. Also, the int4 parameter specified in https://github.com/meta-llama/llama-recipes/blob/main/recipes/quickstart/finetuning/multigpu_finetuning.md#with-fsdp--qlora does not exist.

mreso commented 2 months ago

Hi @AAndersn, thanks for reporting. I was not able to repro this so far, but I will give it another try later today. You're right about the int4; this is a leftover from a back-and-forth while we created the PR for QLoRA. Would you be interested in creating a PR to fix this?

AAndersn commented 2 months ago

@mreso Happy to make a PR to update the docs. I'll also try rolling back to an older version of PyTorch to see if that fixes it, and will update this issue tomorrow.

AAndersn commented 2 months ago

The problem appears to be in the AutoModel.from_pretrained() call inside the finetuning.py script.

I rebuilt my environment today with llama-recipes 0.0.4 and transformers 4.45.0 and am able to run this snippet successfully:

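import torch
from transformers import AutoModel, BitsAndBytesConfig

# 4-bit QLoRA-style config: NF4 weights, bf16 compute, nested ("double")
# quantization, and packed weights kept in a bf16 container; a floating-point
# storage dtype is what lets FSDP flatten the packed weights
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)

# Single-process load; device_map="auto" spreads layers over the visible GPUs
model = AutoModel.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)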

However, if I copy and paste this exact snippet into finetuning.py, the AutoModel call fails with the same message:

python3.11/site-packages/bitsandbytes/nn/modules.py", line 149, in __new__
[rank3]:     self = torch.Tensor._make_subclass(cls, data, requires_grad)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: RuntimeError: Only Tensors of floating point and complex dtype can require gradients
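
In the snippet above, bnb_4bit_quant_storage=torch.bfloat16 changes only the container that the packed 4-bit values live in, not their precision. The back-of-the-envelope arithmetic below assumes each storage element is filled with as many 4-bit values as fit (our simplification of bitsandbytes' packing); STORAGE_BITS and packed_elements are hypothetical names.

```python
# Bit widths of candidate storage dtypes for the packed 4-bit weights
STORAGE_BITS = {"uint8": 8, "bfloat16": 16, "float32": 32}


def packed_elements(num_weights: int, storage: str) -> int:
    """Number of storage elements needed to hold `num_weights` 4-bit values."""
    per_element = STORAGE_BITS[storage] // 4  # 4-bit values per storage element
    return -(-num_weights // per_element)  # integer ceiling division


if __name__ == "__main__":
    n = 4096 * 4096  # a hypothetical 4096 x 4096 projection matrix
    for storage in STORAGE_BITS:
        print(f"{storage:>8}: {packed_elements(n, storage):,} elements")
```

The byte count is the same either way; what changes is the dtype. A bf16 container is floating point, which is what both the requires_grad check and FSDP's flattening step expect.
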
wukaixingxp commented 2 months ago

Hi @AAndersn! Thanks for reporting this. I am wondering whether changing AutoModel to LlamaForCausalLM solves this. Can you try? Thanks!

AAndersn commented 2 months ago

@wukaixingxp - Thank you so much! Changing AutoModel to LlamaForCausalLM fixed it! Testing now with 8B and 70B.

If that works, I will install the pytest suite and then update #681 to include this fix.
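
For reference, AutoModel and LlamaForCausalLM do not build the same module: AutoModel resolves a Llama config to the base LlamaModel, which has no language-modeling head, while LlamaForCausalLM wraps that model and adds lm_head, which a causal-LM fine-tuning loss needs. We have not confirmed that this difference is what triggers the 4-bit loading error. The tiny, randomly initialised comparison below downloads no weights; the config sizes are arbitrary assumptions.

```python
from transformers import AutoModel, LlamaConfig, LlamaForCausalLM

# Tiny config so the comparison runs on CPU in seconds (sizes are arbitrary)
cfg = LlamaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=4,
    num_key_value_heads=4,
)

base = AutoModel.from_config(cfg)  # resolves to the headless LlamaModel
causal = LlamaForCausalLM(cfg)     # a LlamaModel (causal.model) plus lm_head

for model in (base, causal):
    print(type(model).__name__, "has lm_head:", hasattr(model, "lm_head"))
```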

wukaixingxp commented 2 months ago

I tried your command with transformers 4.45.0 and torch 2.4.1, but got this error:

```
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank2]:     return _run_code(code, main_globals, None,
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/runpy.py", line 86, in _run_code
[rank2]:     exec(code, run_globals)
[rank2]:   File "/home/kaiwu/work/llama-recipes/src/llama_recipes/finetuning.py", line 332, in <module>
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank2]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank2]:     component, remaining_args = _CallAndUpdateTrace(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank2]:     component = fn(*varargs, **kwargs)
[rank2]:   File "/home/kaiwu/work/llama-recipes/src/llama_recipes/finetuning.py", line 203, in main
[rank2]:     model = FSDP(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
[rank2]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   [Previous line repeated 2 more times]
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
[rank2]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
[rank2]:     return wrapper_cls(module, **kwargs)
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 565, in _init_param_handle_from_module
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 897, in _materialize_meta_module
[rank2]:     raise e
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 890, in _materialize_meta_module
[rank2]:     module.reset_parameters()  # type: ignore[operator]
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
[rank2]:     raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
[rank2]: AttributeError: 'LlamaRMSNorm' object has no attribute 'reset_parameters'. Did you mean: 'get_parameter'?
/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py:892: UserWarning: Unable to call reset_parameters() for module on meta device with error 'LlamaRMSNorm' object has no attribute 'reset_parameters'. Please ensure that your module of type <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'> implements a reset_parameters() method.
```
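
This second failure is different: per the log, FSDP found modules on the meta device and tried to materialise them by calling reset_parameters(), which LlamaRMSNorm does not define. The sketch below lists modules that own parameters directly but have no reset_parameters(), so you can see in advance which ones would hit this path. RMSNormLike and modules_without_reset are hypothetical names; the toy norm only mimics the relevant property of LlamaRMSNorm.

```python
import torch
from torch import nn


class RMSNormLike(nn.Module):
    """Hypothetical norm layer that, like LlamaRMSNorm in the log, has a
    weight parameter but no reset_parameters() method."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)


def modules_without_reset(model: nn.Module):
    """Names of modules that own parameters directly but lack reset_parameters().

    If FSDP has to initialise such a module from the meta device by calling
    reset_parameters(), it fails the way the log above shows.
    """
    return [
        name
        for name, module in model.named_modules()
        if any(True for _ in module.parameters(recurse=False))
        and not hasattr(module, "reset_parameters")
    ]


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(8, 8), RMSNormLike(8))
    print(modules_without_reset(toy))
```

One common way around this, which we have not checked against this version of finetuning.py, is to pass FSDP a param_init_fn that only allocates storage (for example module.to_empty(device=..., recurse=False)) and set sync_module_states=True so the real weights are broadcast from rank 0.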

AAndersn commented 2 months ago

pip reported a dependency conflict with torch 2.4.1.

I was able to run 8B with 4-bit quantization on torch 2.4.0 by replacing AutoModel with LlamaForCausalLM, or with LlamaForQuestionAnswering (for use with a custom dataset).

AAndersn commented 2 months ago

@wukaixingxp - I see you have made that update in https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/finetuning.py#L139, so I will close this issue as fixed by #686.

Thanks so much for your help!

AAndersn commented 1 week ago

@wukaixingxp -- Today I pulled the latest llama-recipes (0.0.4.post) and am getting the same "Only Tensors of floating point and complex dtype can require gradients" error again. I tried bumping transformers to 4.46.3 and PyTorch to 2.5.1, but still no luck. If you have time to run my bash command with the grammar dataset or a similar CSV dataset, I would really appreciate it.

FYI - I am still able to run an archived copy of your fixes from https://github.com/meta-llama/llama-recipes/commit/9c7a5b421f20b73511d9d0d49078824393e63faa with transformers 4.45.0 and PyTorch 2.4.0, so it's not stopping me for the time being.