bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index
MIT License

AdEMAMix NaN when loading from state_dict #1382

Open darius-lam opened 5 days ago

darius-lam commented 5 days ago

System Info

I'm running a standard training loop in which I save the optimizer state_dict with opt.state_dict(). When I load it back with opt.load_state_dict() to resume, the model immediately produces NaNs after the first backprop step.

This only occurs with the AdEMAMix optimizer:

bnb.optim.AdEMAMix8bit(model.parameters(), lr=lr, t_alpha=T, t_beta3=T)

AdamW and the other optimizers load the state dict perfectly fine. Any ideas?
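For reference, a minimal sketch of the same save/resume cycle with an AdamW variant, which reportedly resumes cleanly; bnb.optim.AdamW8bit is an assumption here, since the exact AdamW class used isn't specified above:

```python
import torch
import bitsandbytes as bnb

# Same save/resume cycle, swapping in an AdamW optimizer for comparison.
# bnb.optim.AdamW8bit is an assumption; per the report, AdamW variants work.
opt = bnb.optim.AdamW8bit(model.parameters(), lr=lr)
# ... run training loop ...
torch.save(opt.state_dict(), "dt_adamw.pt")
opt.load_state_dict(torch.load("dt_adamw.pt"))
# ... training resumes without NaNs ...
```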

Reproduction

```python
import torch
import bitsandbytes as bnb

opt = bnb.optim.AdEMAMix8bit(model.parameters())

# run training loop

torch.save(opt.state_dict(), "dt.pt")

# try resuming opt from state_dict later
opt.load_state_dict(torch.load("dt.pt"))

# run training loop again
```
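A minimal diagnostic sketch (an assumption about how to narrow this down, using only the standard PyTorch optimizer state_dict layout): after load_state_dict, scan the restored state for non-finite values to check whether the state is already corrupted at load time rather than during the first step.

```python
import torch

# Inspect the restored optimizer state for NaN/Inf entries.
# The keys inside each per-parameter state dict are whatever the optimizer
# stores; we simply scan every floating-point tensor we find.
for param_id, state in opt.state_dict()["state"].items():
    for key, value in state.items():
        if torch.is_tensor(value) and value.is_floating_point():
            if not torch.isfinite(value).all():
                print(f"non-finite values in state[{param_id}]['{key}']")
```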

Expected behavior

The optimizer should resume training without producing NaNs.