Closed yongzhuo closed 7 months ago
emmm, I got it after three days of trial and error.
During training, all weights need to be kept in fp32/tf32. If you load the model in fp16, or mostly fp16 with only layer-norm in fp32, you will get loss=nan after a few steps.
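The underlying mechanism can be sketched with plain numpy (this is an illustration, not the training code): fp16 tops out around 65504, so the exponentials inside a softmax/cross-entropy overflow to inf, and inf/inf then yields nan. The same values are harmless in fp32:

```python
import numpy as np

# fp16 overflows: the largest representable value is ~65504
logits16 = np.array([10.0, 50.0], dtype=np.float16)
exps16 = np.exp(logits16)          # exp(50) ~ 5.2e21 -> inf in fp16
softmax16 = exps16 / exps16.sum()  # inf / inf -> nan appears

# identical values computed in fp32 stay finite
logits32 = logits16.astype(np.float32)
exps32 = np.exp(logits32)
softmax32 = exps32 / exps32.sum()  # a valid probability distribution
```

Once a nan like this enters the forward pass, it propagates through the loss and gradients, which matches the behavior of the loss blowing up a few steps into fine-tuning.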
aaaaaa!!!!!!bye
Successful training log for gemma-2b:
{'loss': 2.9753, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.0844, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.3415, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.4752, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.5717, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.0583, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.9449, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.2088, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.4464, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.7823, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.1532, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.9618, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.544, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 2.8314, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.0346, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.3917, 'lr': 0.0002, 'epoch': 1.84}
{'loss': 3.1504, 'lr': 0.0002, 'epoch': 1.84}
The loss always becomes nan after a few fine-tuning steps, whether in fp32 or fp16. Is this an instability issue or something else?
code: https://github.com/yongzhuo/gemma-sft/blob/master/gemma_sft/ft_gemma/train.py
log: