axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

lora example not working with deepspeed zero3 #1481

Open xu3kev opened 5 months ago

xu3kev commented 5 months ago

Expected Behavior

Training should run as usual.

Current behaviour

Training crashes while loading the model, with the following error message:

Traceback (most recent call last):
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/wl678/axolotl/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/wl678/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/home/wl678/axolotl/src/axolotl/cli/train.py", line 55, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/home/wl678/axolotl/src/axolotl/train.py", line 84, in train
    model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
  File "/home/wl678/axolotl/src/axolotl/utils/models.py", line 608, in load_model
    raise err
  File "/home/wl678/axolotl/src/axolotl/utils/models.py", line 527, in load_model
    model = LlamaForCausalLM.from_pretrained(
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3562, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3989, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/transformers/modeling_utils.py", line 822, in _load_state_dict_into_meta_model
    value = type(value)(value.data.to("cpu"), **value.__dict__)
  File "/home/wl678/.conda/envs/axolotl4/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 491, in __new__
    obj = torch.Tensor._make_subclass(cls, data, requires_grad)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients

Steps to reproduce

Run the CodeLlama-7b LoRA example with DeepSpeed ZeRO-3:

accelerate launch -m axolotl.cli.train examples/code-llama/7b/lora.yml --deepspeed deepspeed_configs/zero3.json --eval_sample_packing false

Config yaml

base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

Python Version

3.10

axolotl branch-commit

main/c2b64e4dcff59cfbd754626e5172688433cc13e1

1716649290 commented 5 months ago

I ran into the same problem. Has anyone found a good solution since then?

watsonchua commented 3 months ago

I'm facing this issue now. It has been two months since this was reported. Did anybody find a solution?

kalomaze commented 3 months ago

Also running into this

Nero10578 commented 2 months ago

I ran into this too; setting load_in_8bit to false made it work.
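
For anyone who wants to catch this before launching a run, here is a minimal pre-flight sketch in Python. It only flags the combination reported in this thread (`load_in_8bit: true` together with a ZeRO stage 3 DeepSpeed config); it is not part of axolotl, and the function names are hypothetical. It assumes the axolotl config has `load_in_8bit` as a top-level key, as in the config above, and that the DeepSpeed JSON stores the stage under `zero_optimization.stage`. Turning off 8-bit loading most likely means higher memory use for the base weights, so treat it as a workaround rather than a fix.

```python
import json
import re


def loads_in_8bit(config_yaml: str) -> bool:
    """Return True if the YAML text sets a top-level `load_in_8bit: true`."""
    # Plain-text match to stay stdlib-only; assumes a simple `key: value` line.
    match = re.search(r"^load_in_8bit:\s*(\S+)", config_yaml, flags=re.MULTILINE)
    return bool(match) and match.group(1).lower() == "true"


def zero_stage(deepspeed_json: str) -> int:
    """Read the ZeRO stage from DeepSpeed config text (0 if not set)."""
    cfg = json.loads(deepspeed_json)
    return int(cfg.get("zero_optimization", {}).get("stage", 0))


def hits_reported_combination(config_yaml: str, deepspeed_json: str) -> bool:
    """True when the config matches the failing setup described in this issue."""
    return loads_in_8bit(config_yaml) and zero_stage(deepspeed_json) == 3


if __name__ == "__main__":
    # Paths taken from the reproduction command in this issue.
    with open("examples/code-llama/7b/lora.yml") as f:
        axolotl_cfg = f.read()
    with open("deepspeed_configs/zero3.json") as f:
        ds_cfg = f.read()
    if hits_reported_combination(axolotl_cfg, ds_cfg):
        print("load_in_8bit is true with ZeRO-3; consider setting it to false.")
```

Run it from the repository root so the two relative paths resolve. The check is deliberately narrow: it says nothing about 4-bit loading or other ZeRO stages, which this thread does not cover.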

cjakfskvnad commented 2 weeks ago

Same here.