axolotl-ai-cloud / axolotl


mistral fsdp qlora crashes (cu_seqlens) #1386

Open lucyknada opened 8 months ago

lucyknada commented 8 months ago

Please check that this issue hasn't been reported before.

Current behaviour

Training crashes with:

TypeError: MistralSdpaAttention.forward() got an unexpected keyword argument 'cu_seqlens'
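
For context (my reading of the error, not something confirmed in this thread): sample packing relies on flash-attn's variable-length kernels, which take a cu_seqlens tensor of cumulative sequence boundaries so attention cannot cross packed-example boundaries. The stock SDPA attention forward has no such parameter, so if the model ends up on MistralSdpaAttention instead of the flash-attention class, the extra keyword produces exactly this TypeError. A minimal sketch of what cu_seqlens encodes:

```python
# Illustrative only: cumulative sequence lengths for three packed examples,
# in the shape flash-attn's varlen kernels expect (leading 0, then running sums).
# SDPA attention has no keyword for this, which is what the traceback reflects.
import torch

lengths = torch.tensor([5, 3, 8])  # token counts of the packed examples
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)]).to(torch.int32)
print(cu_seqlens)  # tensor([ 0,  5,  8, 16], dtype=torch.int32)
```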

Steps to reproduce

docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ./cache:/cache/huggingface winglian/axolotl:main-py3.11-cu121-2.1.2

Run that Docker image, then modify the Mistral QLoRA example config to add:

fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  # fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer

(it still crashes regardless of which FSDP options are set)
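
One way to narrow it down (a sketch under my own assumptions, not part of the repro above): load the base model outside axolotl with flash attention requested explicitly. Recent transformers versions raise immediately if flash-attn is missing or unusable rather than silently falling back to SDPA, which helps separate an environment problem from an axolotl patching problem.

```python
# Sanity-check sketch: confirm transformers can actually instantiate the
# flash-attention path for this model. Downloads the 7B weights, so run it
# somewhere that is acceptable.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
# Expect MistralFlashAttention2 here; an exception above points at the flash-attn install.
print(type(model.model.layers[0].self_attn).__name__)
```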

Config yaml

base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

eval_sample_packing: false
datasets:
  - path: /workspace/axolotl/xxx.json
    type: completion
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
  - full_shard
  # - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  # fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

Python Version

3.10, 3.11

axolotl branch-commit

9b6ee83a73d5ffbdc33cfb383a131a08c2b594ff

Acknowledgements

winglian commented 8 months ago

Can you verify that flash attention is installed?

lucyknada commented 8 months ago
# pip show flash-attn
Name: flash-attn
Version: 2.5.5
Summary: Flash Attention: Fast and Memory-Efficient Exact Attention
Home-page: https://github.com/Dao-AILab/flash-attention
Author: Tri Dao
Author-email: trid@cs.stanford.edu
License: 
Location: /root/miniconda3/envs/py3.11/lib/python3.11/site-packages
Requires: einops, ninja, packaging, torch
Required-by:

Seems like it is installed (this is from inside the Docker image).
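
(Not from the thread, but `pip show` only proves the package metadata is present; flash-attn can still fail at import time if the wheel was built against a different torch/CUDA. A quick import check inside the same container rules that out.)

```python
# Hedged sketch: verify the compiled flash-attn extension actually imports in
# this environment, not just that pip knows about the package.
try:
    import flash_attn
    print("flash-attn imports fine, version", flash_attn.__version__)
except Exception as exc:  # ImportError / undefined-symbol errors surface here
    print("flash-attn is installed but unusable:", exc)
```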