
8-Bit DoRA training with FSDP doesn't work, but 4-bit QDoRA does / peft_use_dora is ignored? #1589

Open kalomaze opened 4 months ago

kalomaze commented 4 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

With 8-bit LoRA enabled and peft_use_dora: true, the training run should function (as it does when 4-bit QLoRA is used).

Current behaviour

With this config:

load_in_8bit: true
load_in_4bit: false

and

adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.0
lora_target_linear: true
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
peft_use_dora: true
lora_model_dir:

DoRA training fails during FSDP wrapping (full traceback below).

When I use:

load_in_8bit: false
load_in_4bit: true

and train with peft_use_dora: true, the loss curves are essentially identical to regular 4-bit QLoRA. This leads me to suspect that the run is still plain QLoRA and DoRA is not actually being applied, especially considering this blog post image
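A quick way to test that suspicion is to inspect the built PEFT model for DoRA's extra parameters. This is a minimal sketch, assuming a recent peft release that exposes use_dora on LoraConfig and registers lora_magnitude_vector parameters; the helper is hypothetical debugging code, not part of axolotl:

```python
# Sketch of a DoRA sanity check (assumptions: a recent peft release where
# LoraConfig has `use_dora` and DoRA layers register `lora_magnitude_vector`
# parameters; debugging code only, not axolotl's own API).
def dora_is_active(peft_model) -> bool:
    cfg = peft_model.peft_config["default"]
    has_flag = bool(getattr(cfg, "use_dora", False))
    # Plain LoRA only has lora_A / lora_B; DoRA additionally trains a
    # per-module magnitude vector, so its presence is a reliable signal.
    has_magnitudes = any(
        "lora_magnitude_vector" in name for name, _ in peft_model.named_parameters()
    )
    return has_flag and has_magnitudes
```

If DoRA were active, the trainable-parameter count printed at startup should also be slightly higher than plain LoRA at the same rank, since every targeted module gains a magnitude vector.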

This is the full error log trace for the 8-bit run:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:14<00:00,  3.57s/it]
[2024-05-04 04:55:09,550] [INFO] [axolotl.load_model:720] [PID:1672944] [RANK:0] GPU memory usage after model load: 8.588GB (+0.012GB cache, +2.948GB misc)
[2024-05-04 04:55:09,551] [INFO] [axolotl.load_model:771] [PID:1672944] [RANK:0] converting PEFT model w/ prepare_model_for_kbit_training
[2024-05-04 04:55:09,558] [INFO] [axolotl.load_model:780] [PID:1672944] [RANK:0] converting modules to torch.bfloat16 for flash attention
[2024-05-04 04:55:09,561] [INFO] [axolotl.load_lora:924] [PID:1672944] [RANK:0] found linear modules: ['down_proj', 'v_proj', 'k_proj', 'up_proj', 'gate_proj', 'o_proj', 'q_proj']
[2024-05-04 04:55:32,807] [INFO] [axolotl.load_model:825] [PID:1672946] [RANK:2] GPU memory usage after adapters: 15.434GB (+1.253GB cache, +2.731GB misc)
[2024-05-04 04:55:33,354] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672946] [RANK:2] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:33,354] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672946] [RANK:2] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:33,355] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672946] [RANK:2] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:33,355] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672946] [RANK:2] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 55, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/root/axo_clone/axolotl/src/axolotl/train.py", line 163, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1837, in train
    return inner_training_loop(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1980, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1263, in prepare
    result = tuple(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1264, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1140, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1422, in prepare_model
    model = FSDP(model, **kwargs)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in __init__
    _auto_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in __init__
    _init_param_handle_from_module(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 590, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 602, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 573, in __init__
    self._init_flat_param_and_metadata(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 759, in _validate_tensors_to_flatten
    raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
[2024-05-04 04:55:33,547] [INFO] [axolotl.load_model:825] [PID:1672945] [RANK:1] GPU memory usage after adapters: 15.434GB (+1.253GB cache, +2.731GB misc)
[2024-05-04 04:55:34,084] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672945] [RANK:1] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:34,085] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672945] [RANK:1] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:34,085] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672945] [RANK:1] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:34,086] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672945] [RANK:1] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 55, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/root/axo_clone/axolotl/src/axolotl/train.py", line 163, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1837, in train
    return inner_training_loop(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1980, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1263, in prepare
    result = tuple(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1264, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1140, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1422, in prepare_model
    model = FSDP(model, **kwargs)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in __init__
    _auto_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in __init__
    _init_param_handle_from_module(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 590, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 602, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 573, in __init__
    self._init_flat_param_and_metadata(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 759, in _validate_tensors_to_flatten
    raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
[2024-05-04 04:55:35,150] [INFO] [axolotl.load_model:825] [PID:1672947] [RANK:3] GPU memory usage after adapters: 15.434GB (+1.253GB cache, +2.591GB misc)
[2024-05-04 04:55:35,690] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672947] [RANK:3] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:35,691] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672947] [RANK:3] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:35,691] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672947] [RANK:3] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:35,692] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672947] [RANK:3] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 55, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/root/axo_clone/axolotl/src/axolotl/train.py", line 163, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1837, in train
    return inner_training_loop(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1980, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1263, in prepare
    result = tuple(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1264, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1140, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1422, in prepare_model
    model = FSDP(model, **kwargs)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in __init__
    _auto_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in __init__
    _init_param_handle_from_module(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 590, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 602, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 573, in __init__
    self._init_flat_param_and_metadata(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 759, in _validate_tensors_to_flatten
    raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
trainable params: 85,262,336 || all params: 8,115,523,584 || trainable%: 1.050607950522099
[2024-05-04 04:55:35,993] [INFO] [axolotl.load_model:825] [PID:1672944] [RANK:0] GPU memory usage after adapters: 15.434GB (+1.253GB cache, +2.966GB misc)
[2024-05-04 04:55:36,010] [WARNING] [axolotl.utils.freeze.freeze_layers_except:68] [PID:1672944] [RANK:0] All parameters are frozen. Model will not be trained.
[2024-05-04 04:55:36,096] [INFO] [axolotl.train.log:61] [PID:1672944] [RANK:0] Pre-saving adapter config to ./lora-out
[2024-05-04 04:55:36,267] [INFO] [axolotl.train.log:61] [PID:1672944] [RANK:0] Starting trainer...
[2024-05-04 04:55:36,556] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672944] [RANK:0] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:36,557] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672944] [RANK:0] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:36,557] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672944] [RANK:0] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
[2024-05-04 04:55:36,557] [INFO] [axolotl.utils.samplers.multipack._len_est:184] [PID:1672944] [RANK:0] packing_efficiency_estimate: 0.84 total_num_tokens per device: 249797
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 59, in <module>
    fire.Fire(do_cli)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 35, in do_cli
    return do_train(parsed_cfg, parsed_cli_args)
  File "/root/axo_clone/axolotl/src/axolotl/cli/train.py", line 55, in do_train
    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
  File "/root/axo_clone/axolotl/src/axolotl/train.py", line 163, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1837, in train
    return inner_training_loop(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1980, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1263, in prepare
    result = tuple(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1264, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1140, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1422, in prepare_model
    model = FSDP(model, **kwargs)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 477, in __init__
    _auto_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in __init__
    _init_param_handle_from_module(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 590, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 602, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 573, in __init__
    self._init_flat_param_and_metadata(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 623, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 759, in _validate_tensors_to_flatten
    raise ValueError("Cannot flatten integer dtype tensors")
ValueError: Cannot flatten integer dtype tensors
[2024-05-04 04:55:41,703] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1672944) of binary: /root/axo_clone/venv/bin/python3
Traceback (most recent call last):
  File "/root/axo_clone/venv/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 46, in main
    args.func(args)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1044, in launch_command
    multi_gpu_launcher(args)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 702, in multi_gpu_launcher
    distrib_run.run(args)
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/axo_clone/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
axolotl.cli.train FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-05-04_04:55:41
  host      : 2cca8ebfb513
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 1672945)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2024-05-04_04:55:41
  host      : 2cca8ebfb513
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 1672946)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2024-05-04_04:55:41
  host      : 2cca8ebfb513
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 1672947)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-05-04_04:55:41
  host      : 2cca8ebfb513
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1672944)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Steps to reproduce

Config yaml

base_model: Undi95/Meta-Llama-3-8B-hf
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
  - ^lm_head.weight$
  - ^model.embed_tokens.weight$[:128256]
  - model.layers.[0-9]+.(?!gate).*

datasets:
  - path: ./data/sort_test
    data_files:
      - ./data/sort_test/group_0.jsonl
      - ./data/sort_test/group_1.jsonl
      - ./data/sort_test/group_2.jsonl
      - ./data/sort_test/group_3.jsonl
      - ./data/sort_test/group_4.jsonl
      - ./data/sort_test/group_5.jsonl
      - ./data/sort_test/group_6.jsonl
      - ./data/sort_test/group_7.jsonl
      - ./data/sort_test/group_8.jsonl
      - ./data/sort_test/group_9.jsonl
      - ./data/sort_test/group_10.jsonl
      - ./data/sort_test/group_11.jsonl
      - ./data/sort_test/group_12.jsonl
      - ./data/sort_test/group_13.jsonl
      - ./data/sort_test/group_14.jsonl
      - ./data/sort_test/group_15.jsonl
    type: completion
dataset_prepared_path:
val_set_size: 0
output_dir: ./lora-out

adapter: qlora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.0
lora_target_linear: true
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
peft_use_dora: true
lora_model_dir:

shuffle_merged_datasets: false

sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true

wandb_project: 8b-test
wandb_entity:
wandb_watch:
wandb_name: generic_test_8b_again2
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 0
evals_per_epoch: 1
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
weight_decay: 0.0

special_tokens:
  pad_token: <|end_of_text|>


### Possible solution

_No response_

### Which Operating Systems are you using?

- [X] Linux
- [ ] macOS
- [ ] Windows

### Python Version

3.10.12

### axolotl branch-commit

601c08b4c2d9a8198527e5e33536d3ad499305f0

### Acknowledgements

- [X] My issue title is concise, descriptive, and in title casing.
- [X] I have searched the existing issues to make sure this bug has not been reported yet.
- [X] I am using the latest version of axolotl.
- [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.
winglian commented 4 months ago

This is a known issue. It doesn't work because the underlying PEFT library tries to dequantize the weights to calculate the DoRA norm and scaling after they have already been sharded.

The workaround, unfortunately, would be to implement it completely manually.
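
For context, here is a rough sketch of the computation being described (illustrative only; peft's real implementation differs in detail). DoRA's magnitude initialization needs the norm of the full merged weight, which in turn requires dequantizing the 8-bit base weight of each module:

```python
import torch

def dora_weight_norm(weight: torch.Tensor, lora_A: torch.Tensor,
                     lora_B: torch.Tensor, scaling: float) -> torch.Tensor:
    # Illustrative only: DoRA's magnitude vector starts from the per-output-row
    # norm of the merged weight W + scaling * (B @ A). `weight` stands for the
    # dequantized base weight; with load_in_8bit it is stored as int8, and once
    # FSDP has flattened/sharded it, this per-module step can no longer run.
    merged = weight + scaling * (lora_B @ lora_A)
    return torch.linalg.norm(merged, dim=1)
```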

RicardoDominguez commented 4 months ago

I have also observed that with peft_use_dora: true, the loss curves seem identical to regular 4-bit QLoRA.

winglian commented 4 months ago

I did a test run and was able to confirm that the PEFT dora_init method is called on the LoRA adapter when peft_use_dora: true is set, and that the loss curve is practically identical to regular 8-bit LoRA.
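
(If anyone wants to repeat that check, the sketch below logs each dora_init call; the peft module path and method signature are assumptions about the installed version, not axolotl code.)

```python
# Hypothetical logging shim, applied before the model is built, to see whether
# peft's dora_init runs for each LoRA layer (module path/signature assumed).
import peft.tuners.lora.layer as lora_layer

_orig_dora_init = lora_layer.LoraLayer.dora_init

def _logged_dora_init(self, adapter_name):
    print(f"dora_init called for adapter {adapter_name!r} on {type(self).__name__}")
    return _orig_dora_init(self, adapter_name)

lora_layer.LoraLayer.dora_init = _logged_dora_init
```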