unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Fix: cast logits to float32 in cross_entropy_forward to prevent errors #1254

Closed. Erland366 closed this pull request 2 weeks ago.

Erland366 commented 2 weeks ago

#1251 mentioned that there is a data type mismatch between branches, so we need to upcast to float32.

I saw that other logits calculations also use tl.float32, so this should be correct.
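For context, the idea is just to convert the loaded logits to tl.float32 before any reduction, so every branch of the kernel computes in the same dtype. Below is a minimal sketch of that pattern, not Unsloth's actual cross_entropy_forward; the kernel signature and names here are illustrative assumptions:

```python
import triton
import triton.language as tl

@triton.jit
def cross_entropy_forward_sketch(
    logits_ptr, labels_ptr, loss_ptr,
    n_cols, BLOCK_SIZE: tl.constexpr,
):
    # One program handles one row of logits.
    row  = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    # Logits may be stored as float16/bfloat16; upcast to float32 here so the
    # max/logsumexp reduction and the gathered label logit share one dtype.
    logits = tl.load(
        logits_ptr + row * n_cols + cols, mask=mask, other=-float("inf")
    ).to(tl.float32)

    label = tl.load(labels_ptr + row)

    # Numerically stable logsumexp in float32.
    max_logit = tl.max(logits, axis=0)
    logsumexp = max_logit + tl.log(tl.sum(tl.exp(logits - max_logit), axis=0))

    # loss = logsumexp - logits[label]
    label_logit = tl.sum(tl.where(cols == label, logits, 0.0), axis=0)
    tl.store(loss_ptr + row, logsumexp - label_logit)
```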

Default continued pretraining uses this amount of VRAM, so I don't think this increases usage either:

[Screenshots: VRAM usage]

Erland366 commented 2 weeks ago

Tested on Colab too:

[Screenshot: Colab run]

itshahmir commented 2 weeks ago

I am facing this issue in this notebook:

https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing#scrollTo=yqxqAZ7KJ4oL

[Screenshot of the error]