axolotl-ai-cloud / axolotl


User-defined dataset config for DPO training doesn't fetch the custom fields #1417

Open abhinand5 opened 5 months ago

abhinand5 commented 5 months ago

Please check that this issue hasn't been reported before.

Expected Behavior

When a dataset is defined like this, axolotl should apply the custom type config (the field mappings and format strings).

rl: dpo
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    split: train
    type:
      field_system: system
      field_prompt: input
      field_chosen: chosen
      field_rejected: rejected
      prompt_format: "<bos>{system}\n\n### Instruction:\n{prompt}\n### Response:\n"
      chosen_format: "{chosen}<eos>"
      rejected_format: "{rejected}<eos>"
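
For context, my reading of these format strings (an illustration only, not taken from axolotl's code) is that a single dataset row gets rendered roughly like this:

# Illustration only: how the format strings above are meant to apply to one
# row; the row values are made up, the field names come from the config.
row = {
    "system": "You are a helpful assistant.",
    "input": "Translate 'hello' to French.",
    "chosen": "'Hello' in French is 'bonjour'.",
    "rejected": "I don't know.",
}

prompt = "<bos>{system}\n\n### Instruction:\n{prompt}\n### Response:\n".format(
    system=row["system"],   # field_system: system
    prompt=row["input"],    # field_prompt: input
)
chosen = "{chosen}<eos>".format(chosen=row["chosen"])           # field_chosen: chosen
rejected = "{rejected}<eos>".format(rejected=row["rejected"])   # field_rejected: rejected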

Current behaviour

Instead, axolotl currently throws a KeyError even though all the fields required to parse the dataset are provided.

Debugging with print statements shows that axolotl.utils.config.validate_config removes some of the fields in ds_cfg[idx]["type"]. Only field_system survives; interestingly, it is also the only field shared with the custom SFT dataset types.
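
For anyone who wants to confirm this, the following rough inspection sketch (not from the original report) shows the idea. It assumes validate_config can be called directly on a DictDefault built from the YAML and that these import paths match the installed version; the exact signature and return value may differ between axolotl releases.

import yaml

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# "dpo-config.yml" is a hypothetical filename for the config from this issue
with open("dpo-config.yml", encoding="utf-8") as fin:
    cfg = DictDefault(yaml.safe_load(fin))

print("before validation:", cfg["datasets"][0]["type"])
# expected: field_system, field_prompt, field_chosen, field_rejected and the
# *_format strings are all present here

validated = validate_config(cfg) or cfg  # cover both return-value and in-place styles

print("after validation:", validated["datasets"][0]["type"])
# observed (per this report): only field_system survives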

Steps to reproduce

Try the custom dataset configuration shown above.
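
The full config is in the next section; launching training on it in the usual way (e.g. accelerate launch -m axolotl.cli.train config.yml, assuming the standard CLI entrypoint and that the file is saved as config.yml) should reproduce the KeyError.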

Config yaml

base_model: abhinand/my-model
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

rl: dpo
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    split: train
    type:
      field_system: system
      field_prompt: input
      field_chosen: chosen
      field_rejected: rejected
      prompt_format: "<bos>{system}\n\n### Instruction:\n{prompt}\n### Response:\n"
      chosen_format: "{chosen}<eos>"
      rejected_format: "{rejected}<eos>" 

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./models/my-model-dpo

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true

lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project: my-model-dpo
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 2
num_epochs: 2
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-7

train_on_inputs: false
group_by_length: false
bf16: false
fp16: true
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

# warmup_steps: 10
warmup_ratio: 0.1
eval_steps:
eval_table_size:
eval_table_max_new_tokens: 128
save_steps: 239
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
save_safetensors: true

Possible solution

The fix likely belongs in the validator axolotl.utils.config.models.input.v0_4_1.AxolotlInputConfig: when the training type is DPO, custom dataset types should be cast to UserDefinedDPOType instead of UserDefinedPrompterType, which would remove this behaviour.

That is exactly what I did in my fork, and it works now: https://github.com/abhinand5/axolotl/commit/3b0873cc41ef057e7861c6a72ad258ba0d368e84
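
For illustration, here is a simplified, self-contained sketch of that kind of change (pydantic v2). These are stand-in models, not the real axolotl classes (the real ones live in axolotl.utils.config.models.input.v0_4_1 and carry many more fields), and the actual change in the linked commit may look different.

from typing import Optional, Union

from pydantic import BaseModel, model_validator


class UserDefinedPrompterType(BaseModel):
    # stand-in for the custom SFT type; field_system is the only field it
    # shares with the DPO type below
    field_system: Optional[str] = None
    field_instruction: Optional[str] = None
    field_output: Optional[str] = None


class UserDefinedDPOType(BaseModel):
    # stand-in for the custom DPO type
    field_system: Optional[str] = None
    field_prompt: Optional[str] = None
    field_chosen: Optional[str] = None
    field_rejected: Optional[str] = None
    prompt_format: Optional[str] = None
    chosen_format: Optional[str] = None
    rejected_format: Optional[str] = None


class Dataset(BaseModel):
    path: str
    split: Optional[str] = None
    # In the real config models the dataset entry is a union of SFT- and
    # DPO-style models; when the SFT branch wins, DPO-only fields are
    # silently dropped.
    type: Union[str, UserDefinedPrompterType, UserDefinedDPOType, None] = None


class Config(BaseModel):
    rl: Optional[str] = None
    datasets: list[Dataset]

    @model_validator(mode="before")
    @classmethod
    def coerce_dpo_dataset_types(cls, data):
        # When rl == "dpo", force dict-valued dataset types through
        # UserDefinedDPOType so no custom fields get dropped.
        if isinstance(data, dict) and data.get("rl") == "dpo":
            for dataset in data.get("datasets") or []:
                if isinstance(dataset, dict) and isinstance(dataset.get("type"), dict):
                    dataset["type"] = UserDefinedDPOType(**dataset["type"])
        return data

With a validator along these lines, parsing the DPO config from this issue keeps field_prompt, field_chosen, field_rejected and the *_format strings instead of collapsing the type down to field_system alone.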

Please let me know whether a PR for this would make sense; it is possible I am missing some important config option that is causing this issue.

Which Operating Systems are you using?

Python Version

3.11

axolotl branch-commit

main

Acknowledgements

maziyarpanahi commented 2 months ago

Hi @abhinand5, have you found a workaround for this?

RameshArvind commented 3 weeks ago

Running into the same issue with user-defined KTO datasets as well.