xfactlab / orpo

Official repository for ORPO
Apache License 2.0

4xA6000 training - failed to save model #3

Closed: rkinas closed this issue 8 months ago

rkinas commented 8 months ago

Hi, thank you for providing ORPO. I ran a quick training run, but after it finished the model was not saved: the script crashed.

My setup:

I used the original fsdp.yaml configuration (no changes):

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
    --lr 5e-6 \
    --warmup_steps 100 \
    --model_name openchat/openchat-3.5-0106 \
    --data_name argilla/dpo-mix-7k \
    --num_train_epochs 1 \
    --prompt_max_length 512 \
    --response_max_length 2048 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --num_proc 1

Best regards,
Remek

jiwooya1000 commented 8 months ago

Hello Remek, I appreciate your interest in ORPO!

Could you share the error log from the model-saving step?

We've noticed that, even on the main branch, the model is not saved correctly when training with FSDP, so we are working on a fix.

jiwooya1000 commented 8 months ago

Hello, we have resolved the conflict between the torch.compile option in TrainingArguments and FSDP model saving.

Although we are not sure what error you encountered, the problem on our side was that the model was saved with an _orig_mod. prefix on every key in the weight map:

{
  "metadata": {
    "total_size": 5559367680
  },
  "weight_map": {
    "_orig_mod.lm_head.bias": "model-00002-of-00002.safetensors",
    "_orig_mod.lm_head.weight": "model-00002-of-00002.safetensors",
    "_orig_mod.model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "_orig_mod.model.final_layernorm.bias": "model-00002-of-00002.safetensors",
    "_orig_mod.model.final_layernorm.weight": "model-00002-of-00002.safetensors",
    ...
  }
}
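For anyone holding a checkpoint that was already saved with this prefix, here is a minimal sketch of remapping the keys. It is not the fix in the ORPO repository (which changes how the model is saved). It assumes the prefix comes from the `_orig_mod` attribute that `torch.compile` uses for the original module when it wraps it. The helper name `strip_compile_prefix` is hypothetical.

```python
from typing import Dict, TypeVar

V = TypeVar("V")

# Prefix assumed to be added in front of every parameter name when the
# torch.compile wrapper's state dict is saved instead of the original module's.
COMPILE_PREFIX = "_orig_mod."


def strip_compile_prefix(mapping: Dict[str, V], prefix: str = COMPILE_PREFIX) -> Dict[str, V]:
    """Return a copy of `mapping` with `prefix` removed from the start of each key.

    Works on a state dict (name -> tensor) or on the "weight_map" of a
    safetensors index file (name -> shard filename). Keys without the
    prefix are kept unchanged.
    """
    stripped: Dict[str, V] = {}
    for key, value in mapping.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        # Refuse to silently overwrite if both "x" and "_orig_mod.x" exist.
        if new_key in stripped:
            raise KeyError(f"duplicate key after stripping prefix: {new_key}")
        stripped[new_key] = value
    return stripped


weight_map = {
    "_orig_mod.lm_head.weight": "model-00002-of-00002.safetensors",
    "_orig_mod.model.embed_tokens.weight": "model-00001-of-00002.safetensors",
}
print(strip_compile_prefix(weight_map))
# {'lm_head.weight': 'model-00002-of-00002.safetensors', 'model.embed_tokens.weight': 'model-00001-of-00002.safetensors'}
```

Each .safetensors shard also stores the tensor names, so rewriting only the index's weight_map is not enough. The same remapping would need to be applied to the loaded state dict before re-saving it. On the saving side, the equivalent idea is to save the unwrapped module, for example `getattr(model, "_orig_mod", model)`, rather than the compiled wrapper.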

On our setups (an A6000 cluster and an A100 cluster), we verified that with the latest commit the model is saved correctly and loads with AutoModelForCausalLM.from_pretrained.
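A quick way to confirm that a saved sharded checkpoint no longer carries the prefix is to scan the keys of its index file before calling from_pretrained. The sketch below assumes the standard `model.safetensors.index.json` layout with a top-level `weight_map` object. The function name `prefixed_keys` is hypothetical.

```python
import json
from pathlib import Path
from typing import List


def prefixed_keys(checkpoint_dir: str, prefix: str = "_orig_mod.") -> List[str]:
    """List weight names in the safetensors index that still start with `prefix`."""
    index_path = Path(checkpoint_dir) / "model.safetensors.index.json"
    with index_path.open(encoding="utf-8") as f:
        index = json.load(f)
    # "weight_map" maps each tensor name to the shard file that stores it.
    return sorted(k for k in index["weight_map"] if k.startswith(prefix))
```

For a correctly saved checkpoint, `prefixed_keys(output_dir)` should return an empty list.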

Could you try the latest version and check whether this fix resolves your issue?

rkinas commented 8 months ago

Hi, thank you for the answer. It turned out that removing the generation_config.json file from the model directory (I had to download the OpenChat3.5 model locally) solved the problem, and the model now saves correctly. The error was:

ValueError: The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. Fix these issues to save the configuration.
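The thread does not say which check failed for this generation_config.json. One common trigger for this ValueError in recent transformers versions is a config that sets sampling-only parameters (such as temperature or top_p) while do_sample is false. The sketch below assumes that cause and inspects the JSON directly instead of calling transformers. It covers only a subset of what GenerationConfig.validate() checks, and the neutral values it compares against are assumed to match transformers' defaults. The function names are hypothetical. Fixing the flagged fields would be an alternative to deleting the file.

```python
import json
from typing import Any, Dict, List

# Sampling-only parameters and the neutral values they are assumed to default to.
# This is only a subset of the checks in transformers' GenerationConfig.validate().
SAMPLING_DEFAULTS: Dict[str, float] = {"temperature": 1.0, "top_p": 1.0, "typical_p": 1.0}


def sampling_flags_without_do_sample(config: Dict[str, Any]) -> List[str]:
    """Return sampling-only keys set to non-default values while do_sample is false."""
    if config.get("do_sample", False):
        return []  # sampling is enabled, so these parameters are actually used
    return [
        key
        for key, default in SAMPLING_DEFAULTS.items()
        if config.get(key) is not None and config[key] != default
    ]


def load_config(path: str) -> Dict[str, Any]:
    """Read a generation_config.json file into a plain dict."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```

If `sampling_flags_without_do_sample(load_config(path))` returns a non-empty list, setting do_sample to true or removing the flagged keys would likely let the configuration be saved; which option fits depends on how the model is meant to generate.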

rkinas commented 8 months ago

BTW: I applied ORPO to OpenChat3.5-0106, which is already SFT-trained and probably DPO-tuned as well. The ORPO-tuned model scored higher on MT-Bench (the first-turn score is much higher).

jiwooya1000 commented 8 months ago

Thank you for sharing the nice result! Did you train the model with the code above?

rkinas commented 8 months ago

Yes, exactly the same code; the only changes were in the config files and saving the tokenizer at the final stage.

jiwooya1000 commented 8 months ago

Glad to see that ORPO gives promising results on already fine-tuned chat models too 😀 Closing the issue since the model-saving problem is resolved. Thank you again for sharing the nice result!