google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0
5.26k stars · 503 forks

Loss always becomes NaN after a few fine-tuning steps, whether fp32 or fp16 #24

Closed yongzhuo closed 7 months ago

yongzhuo commented 7 months ago

The loss always becomes NaN after a few fine-tuning steps, whether in fp32 or fp16. Is this an instability in the model, or something else?

code: https://github.com/yongzhuo/gemma-sft/blob/master/gemma_sft/ft_gemma/train.py

log:

{'loss': 5.6332, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.3049, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 153318.95, 'grad_norm': nan, 'learning_rate': 0.0002, 'epoch': 0.09}
{'loss': 5.7517, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.6231, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.0968, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.8938, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 6.1305, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.9105, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.4063, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.2371, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.4602, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.5166, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.0093, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
yongzhuo commented 7 months ago

Emmm, I figured it out after three days of trial and error.

During training, all weights being updated need to be kept in fp32/tf32. If you load the model in fp16, or mostly fp16 (with only the layer norms in fp32), the loss becomes NaN a few steps later.
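For anyone hitting the same thing, here is a minimal sketch of the loading setup that avoids the NaN, assuming the Hugging Face `transformers` loading path used in the script linked above and the `google/gemma-2b` hub checkpoint (names are illustrative, not from this repo):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# On Ampere or newer GPUs, allow TF32 matmuls: you keep fp32 master
# weights (and their dynamic range) while recovering much of the
# throughput that fp16 would have given.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

model_id = "google/gemma-2b"  # assumed Hugging Face checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load every weight in fp32. Do NOT pass torch_dtype=torch.float16 for
# fine-tuning; even with the layer norms kept in fp32, the loss
# diverged to NaN after a few steps in my runs.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
)
```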

aaaaaa!!!!!!bye

Successful training logs for gemma-2b:

{'loss': 2.9753, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.0844, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.3415, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.4752, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.5717, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.0583, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.9449, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.2088, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.4464, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.7823, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.1532, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.9618, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.544, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.8314, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.0346, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.3917, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.1504, 'lr': 0.0002, 'epoch': 1.84}