
Deepspeed zero3 + LoRA: RuntimeError: Only Tensors of floating point and complex dtype can require gradients #2068

[Open] bursteratom opened this issue 2 days ago

bursteratom commented 2 days ago

Please check that this issue hasn't been reported before.

Expected Behavior

I expect deepspeed zero3 and 8-bit LoRA to be compatible and to run without error.

Current behaviour

When loading the model with deepspeed zero3 and 8-bit LoRA enabled, I ran into the error RuntimeError: Only Tensors of floating point and complex dtype can require gradients:

[Screenshot: traceback ending in RuntimeError: Only Tensors of floating point and complex dtype can require gradients]

However, using zero3 in tandem with 4-bit QLoRA, or doing a full fine-tune with zero3 enabled, works fine.
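
For context, this is the generic error PyTorch raises whenever requires_grad is set on a non-floating-point tensor. The snippet below is purely illustrative (not axolotl code) and reproduces the same message by marking an int8 tensor as trainable, which is presumably what ends up happening to an 8-bit quantized weight somewhere along this path:

import torch

# Illustration only: asking an integer (e.g. int8-quantized) tensor to require
# gradients raises the exact error reported above.
w = torch.zeros(4, 4, dtype=torch.int8)
w.requires_grad_(True)
# RuntimeError: Only Tensors of floating point and complex dtype can require gradients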

Steps to reproduce

  1. Set up 8-bit LoRA (load_in_8bit: true, adapter: lora)
  2. Enable deepspeed zero3
  3. Launch training; it fails with RuntimeError: Only Tensors of floating point and complex dtype can require gradients (a rough standalone approximation of the failing setup is sketched below)
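
For reference, here is my rough standalone approximation of the failing combination in plain transformers + peft (these are assumed-equivalent calls, not axolotl's actual code); the model name and LoRA settings mirror the config below, and it would need to be launched under a DeepSpeed ZeRO-3 config to reproduce the error:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Load the base model with bitsandbytes 8-bit quantization (load_in_8bit: true)
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Meta-Llama-3-8B",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# Attach a LoRA adapter mirroring the yaml below: r/alpha/dropout, all linear
# layers targeted, with embed_tokens and lm_head kept fully trainable.
lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules="all-linear",
    modules_to_save=["embed_tokens", "lm_head"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)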

Config yaml

base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/lora-out

sequence_len: 4096
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
  - embed_tokens
  - lm_head

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>

Possible solution

(Putting on a tinfoil hat) I think it's a bug within the axolotl code base rather than some deeper issue with deepspeed zero3, seeing as it works with QLoRA.
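
If it helps triage, a small diagnostic along these lines (the helper name is mine, nothing from axolotl) could be run on the loaded model to list its non-floating-point parameters, i.e. the only candidates that can trigger this RuntimeError if something tries to mark them trainable:

import torch

def find_non_float_params(model: torch.nn.Module):
    # Parameters that cannot have requires_grad=True; comparing this list
    # between the 8-bit and 4-bit runs might show where the two paths diverge.
    return [
        (name, param.dtype)
        for name, param in model.named_parameters()
        if not (param.dtype.is_floating_point or param.dtype.is_complex)
    ]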

Which Operating Systems are you using?

Python Version

3.11.10

axolotl branch-commit

main

Acknowledgements

bursteratom commented 1 day ago

This is being worked on via PR #1852 and https://github.com/huggingface/transformers/pull/32943.