axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0

ORPO, DPO don't work with Mixtral-8x22B: FSDP + QLoRA & bigstral-ds-zero3 #1534

Open 0-hero opened 7 months ago

0-hero commented 7 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

DPO/ORPO training should run successfully

Current behaviour

Models tested (none have any issues with inference):

Machines tested (each type tried from multiple providers):

Images tested:

I ran tests with every combination of the above.

Both issues below occur with ORPO and DPO alike.

Issue 1 - FSDP + QLoRA

https://github.com/OpenAccess-AI-Collective/axolotl/issues/1494

Issue 2 - bigstral-ds-zero3

This happens at some point before step 20. I tried reducing both of the settings below to 1, but the issue persists:

gradient_accumulation_steps: 1
micro_batch_size: 1

Training hangs and eventually stops with an NCCL timeout (see https://github.com/huggingface/accelerate/issues/314). GPU utilization also drops once it hangs; example below:

[Screenshot from 2024-04-18, 9:44:07 AM: GPU utilization dropping after training hangs]
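For context on what reducing those two settings changes: in data-parallel training (which both DeepSpeed ZeRO-3 and FSDP full sharding are), the number of examples consumed per optimizer step is micro_batch_size × gradient_accumulation_steps × world_size. The sketch below works that out for the configs in this issue. The GPU count is an assumption (a hypothetical single 8-GPU node), since it is not stated above; for DPO/ORPO each "example" is one prompt with its chosen and rejected responses.

```python
def effective_batch_size(micro_batch_size: int, gradient_accumulation_steps: int, world_size: int) -> int:
    """Examples (preference pairs for DPO/ORPO) consumed per optimizer step with data parallelism."""
    return micro_batch_size * gradient_accumulation_steps * world_size


WORLD_SIZE = 8  # hypothetical: one node with 8 GPUs; the issue does not state the GPU count

original = effective_batch_size(4, 8, WORLD_SIZE)  # micro_batch_size / gradient_accumulation_steps from the configs
reduced = effective_batch_size(1, 1, WORLD_SIZE)   # both settings reduced to 1 while debugging
print(original, reduced)  # prints: 256 8
```

So under this assumption the reduced run processes 32 times fewer examples per optimizer step, yet the hang still appears within the first 20 steps.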

Steps to reproduce

Start training with either of the configs below.

FSDP + QLoRA config

base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

rl: dpo
datasets:
  - path: argilla/ultrafeedback-binarized-preferences-cleaned
    split: train
    type: chatml.ultra

dpo_beta: 0.1

chat_template: chatml
default_system_message: You are a helpful assistant

dataset_prepared_path: data
val_set_size: 0
output_dir: output

sequence_len: 8192
sample_packing: false
pad_to_sequence_len: false

adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
  - embed_tokens
  - lm_head

gradient_accumulation_steps: 8
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint: 
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP

special_tokens:
  bos_token: "<s>"
  eos_token: "<|im_end|>"
  unk_token: "<unk>"
tokens:
  - "<|begin_func|>"
  - "<|end_func|>"
  - "<|begin_func_response|>"
  - "<|end_func_response|>"
  - "<|im_start|>"
  - "<|im_end|>"

bigstral-ds-zero3 config

base_model: 0-hero/Matter-0.2-8x22B
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

unfrozen_parameters:
  - ^lm_head.weight$
  - ^model.embed_tokens.weight$
  - model.layers.4[4-9]+.block_sparse_moe.gate
  - model.layers.4[4-9]+.block_sparse_moe.experts
  - model.layers.5[0-5]+.block_sparse_moe.gate
  - model.layers.5[0-5]+.block_sparse_moe.experts

model_config:
  output_router_logits: true

rl: orpo
datasets:
  - path: mlabonne/orpo-mix-40k
    split: train
    type: orpo.chat_template

chat_template: chatml
default_system_message: You are a helpful assistant

dataset_prepared_path: data
val_set_size: 0
output_dir: output

sequence_len: 8192
sample_packing: false
pad_to_sequence_len: false

gradient_accumulation_steps: 8
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint: 
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "<|im_end|>"
  unk_token: "<unk>"
tokens:
  - "<|begin_func|>"
  - "<|end_func|>"
  - "<|begin_func_response|>"
  - "<|end_func_response|>"
  - "<|im_start|>"
  - "<|im_end|>"
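The `unfrozen_parameters` list in the bigstral-ds-zero3 config is a set of regular expressions intended to select which weights stay trainable while the rest of the model is frozen. The sketch below checks which layers those patterns pick out. It assumes the patterns are applied with Python's `re.match` against Hugging Face-style parameter names (as returned by `named_parameters()`), and assumes 56 decoder layers with 8 experts each, as in Mixtral-8x22B. The parameter list is a hypothetical, abbreviated stand-in, not names loaded from the real checkpoint.

```python
import re

# Patterns copied verbatim from the bigstral-ds-zero3 config above.
UNFROZEN_PATTERNS = [
    r"^lm_head.weight$",
    r"^model.embed_tokens.weight$",
    r"model.layers.4[4-9]+.block_sparse_moe.gate",
    r"model.layers.4[4-9]+.block_sparse_moe.experts",
    r"model.layers.5[0-5]+.block_sparse_moe.gate",
    r"model.layers.5[0-5]+.block_sparse_moe.experts",
]


def is_unfrozen(name: str, patterns: list[str] = UNFROZEN_PATTERNS) -> bool:
    # Assumption: each pattern is tried with re.match, i.e. anchored at the start of the name.
    return any(re.match(p, name) for p in patterns)


def sample_param_names(num_layers: int = 56, num_experts: int = 8) -> list[str]:
    # Hypothetical, abbreviated Mixtral parameter names (one attention weight,
    # the router gate, and one weight per expert in each layer).
    names = ["model.embed_tokens.weight", "lm_head.weight", "model.norm.weight"]
    for i in range(num_layers):
        prefix = f"model.layers.{i}"
        names.append(f"{prefix}.self_attn.q_proj.weight")
        names.append(f"{prefix}.block_sparse_moe.gate.weight")
        for e in range(num_experts):
            names.append(f"{prefix}.block_sparse_moe.experts.{e}.w1.weight")
    return names


unfrozen = [n for n in sample_param_names() if is_unfrozen(n)]
# "model.layers.44.block_sparse_moe..." -> layer index is the third dot-separated field.
layers = sorted({int(n.split(".")[2]) for n in unfrozen if n.startswith("model.layers.")})
print(layers)
```

Under these assumptions the patterns select the router gate and experts of layers 44 through 55 (the last twelve layers), plus the embeddings and LM head; attention weights stay frozen. The `.` in the patterns is unescaped and so matches any character, which is harmless for these names.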

Config yaml

No response

Possible solution

No response

Which Operating Systems are you using?

Python Version

3.10, 3.11

axolotl branch-commit

main/0eadfc8

Acknowledgements

0-hero commented 7 months ago

@winglian Raised this as a new issue, as mentioned in the other discussion.