xfactlab / orpo

Official repository for ORPO
Apache License 2.0

Recreating the setup with CUDA 12.1; grad norm is nan #31

Closed: Jayant1234 closed this issue 2 months ago

Jayant1234 commented 3 months ago

I am trying to recreate the setup, but with CUDA 12.1. I am installing the following library versions:

requirements.txt

cachetools==5.3.3
appdirs==1.4.4
Jinja2==3.1.2
bitsandbytes==0.43.1
numpy==1.26.4
einops==0.7.0
networkx==3.2.1
ninja==1.11.1.1
pillow==10.2.0
torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121
triton==2.1.0
typing_extensions==4.8.0
termcolor==2.4.0
protobuf==3.20.3
datasets==2.17.0
fsspec
transformers
tokenizers
huggingface
peft
sympy
tqdm
wandb
wheel
packaging
accelerate
flash-attn==2.5.6 --no-build-isolation
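
A quick way to confirm that the stack actually resolves to the CUDA 12.1 builds is a small version check like the sketch below (the expected values in the comments are assumptions based on the pins above):

    import torch
    import transformers
    import flash_attn

    print("torch:", torch.__version__)                # expect 2.1.2+cu121
    print("CUDA runtime:", torch.version.cuda)        # expect 12.1
    print("bf16 supported:", torch.cuda.is_bf16_supported())
    print("transformers:", transformers.__version__)
    print("flash-attn:", flash_attn.__version__)      # expect 2.5.6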

I am running the following script:

    accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
    --lr 5e-6 \
    --lr_scheduler_type inverse_sqrt \
    --alpha 0.1 \
    --torch_compile False \
    --warmup_steps 200 \
    --model_name mistralai/Mistral-7B-v0.1 \
    --data_name argilla/ultrafeedback-binarized-preferences-cleaned \
    --num_train_epochs 1 \
    --prompt_max_length 1792 \
    --response_max_length 2048 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --gradient_accumulation_steps 1 \
    --num_proc 1 \
    --flash_attention_2 

The training run proceeds as follows and then becomes unstable, with grad_norm turning into nan:

    {'loss': 1.2294, 'grad_norm': 9.907266616821289, 'learning_rate': 1.25e-06, 'epoch': 0.01}
    {'loss': 1.1362, 'grad_norm': 11.611167907714844, 'learning_rate': 2.5e-06, 'epoch': 0.03}
    {'loss': 1.1392, 'grad_norm': 9.815629959106445, 'learning_rate': 3.7500000000000005e-06, 'epoch': 0.04}
    {'loss': 1.1631, 'grad_norm': 10.303666114807129, 'learning_rate': 5e-06, 'epoch': 0.05}
    {'loss': 1.1713, 'grad_norm': 15.603175163269043, 'learning_rate': 4.47213595499958e-06, 'epoch': 0.07}
    {'loss': 1.1804, 'grad_norm': 32.802734375, 'learning_rate': 4.082482904638631e-06, 'epoch': 0.08}
    {'loss': 1.0949, 'grad_norm': nan, 'learning_rate': 3.7796447300922724e-06, 'epoch': 0.09}

Could you please help? Do you have any suggestions about what is going wrong with the setup?
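
For anyone debugging a similar instability, a minimal way to find which parameter first produces a non-finite gradient is a check right after backward(). The sketch below is illustrative only: it assumes direct access to the training loop (it is not the repo's main.py), and the per-parameter view may differ under FSDP sharding.

    import torch

    def first_nonfinite_grad(model: torch.nn.Module):
        """Return the name of the first parameter whose gradient contains NaN/Inf, if any."""
        for name, param in model.named_parameters():
            if param.grad is not None and not torch.isfinite(param.grad).all():
                return name
        return None

    # Inside the training loop, right after loss.backward():
    # bad = first_nonfinite_grad(model)
    # if bad is not None:
    #     print(f"non-finite gradient first appeared in: {bad}")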

jiwooya1000 commented 3 months ago

Hello @Jayant1234,

I have never run into this situation. Could you specify the library versions that reproduce the error? Also, could you try running with TRL & alignment-handbook to check whether the error persists?
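
For reference, a minimal TRL-based run could look roughly like the sketch below. It is an illustrative sketch rather than the exact alignment-handbook recipe: the hyperparameters are copied from the command above, the output_dir is made up, TRL's beta roughly plays the role of --alpha here, and the dataset's chosen/rejected message lists may need to be flattened to plain strings first.

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import ORPOConfig, ORPOTrainer

    model_name = "mistralai/Mistral-7B-v0.1"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # May need preprocessing so prompt/chosen/rejected are plain strings.
    train_dataset = load_dataset(
        "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
    )

    args = ORPOConfig(
        output_dir="./orpo-mistral-trl",   # illustrative path
        beta=0.1,                          # roughly corresponds to --alpha 0.1
        learning_rate=5e-6,
        lr_scheduler_type="inverse_sqrt",
        warmup_steps=200,
        num_train_epochs=1,
        per_device_train_batch_size=8,
        max_prompt_length=1792,
        max_length=2048,
        bf16=True,
    )

    trainer = ORPOTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()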

Jayant1234 commented 3 months ago

Hi @jiwooya1000, thanks for getting back! I applied your recent changes, especially the one ensuring that padding is on the right, and that seems to have resolved the issue! I really admire your work on ORPO and how innovatively it implements a self-reference!
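
In case it helps others hitting the same NaN behavior with FlashAttention-2, the change is along these lines (a minimal sketch; the model loading below is illustrative rather than the repo's exact code):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "mistralai/Mistral-7B-v0.1"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"   # right-padding during training; left-padding with FA2 led to NaNs here

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )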

For others' reference, I ran pip install -r requirements.txt in a newly created pip environment, where requirements.txt was:

cachetools==5.3.3
appdirs==1.4.4
Jinja2==3.1.2
bitsandbytes==0.43.1
numpy==1.26.4
einops==0.7.0
networkx==3.2.1
ninja==1.11.1.1
pillow==10.2.0
torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121
triton==2.1.0
typing_extensions==4.8.0
termcolor==2.4.0
protobuf==3.20.3
datasets==2.17.0
fsspec
transformers
tokenizers
huggingface
peft
sympy
tqdm
wandb
wheel
packaging
accelerate
flash-attn==2.5.6 --no-build-isolation