🧐 Problem Description
The vocab size is limited to 64K (I think?) because of Triton's limit on the block size.
💡 Proposed Solution
The standard approach is to loop over blocks of the vocabulary, as is done for example in https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py (which was used as a basis for the Fast-LLM implementation).
Fast-LLM removed the loop to simplify the kernel and optimize the case of a smaller vocab size, so we'll need to bring looping back. Some care will be needed to keep the current performance when looping is unnecessary (looping means multiple reads of the logits, etc.). A rough sketch of the looped forward pass is included below.
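As a rough illustration only, here is a minimal sketch of a block-looping forward pass using a streaming log-sum-exp, loosely following the flash-attention kernel linked above. The kernel name, arguments, and `BLOCK_SIZE` handling are hypothetical and do not reflect Fast-LLM's actual kernel signature; the real implementation would also need the backward pass, label smoothing, parallel-vocab handling, etc.

```python
import triton
import triton.language as tl


@triton.jit
def looped_cross_entropy_fwd(  # hypothetical name, not Fast-LLM's actual kernel
    logits_ptr,        # (n_rows, vocab) logits
    labels_ptr,        # (n_rows,) target indices
    losses_ptr,        # (n_rows,) output losses
    n_cols,            # vocab size; may exceed the maximum block size
    logits_stride,     # row stride of the logits tensor
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    row_ptr = logits_ptr + row * logits_stride
    label = tl.load(labels_ptr + row)
    target_logit = tl.load(row_ptr + label).to(tl.float32)

    # Streaming log-sum-exp over vocab blocks: the row is read in chunks of
    # BLOCK_SIZE, so the vocab size is no longer capped by the block size.
    m = -float("inf")  # running max
    s = 0.0            # running sum of exp(logit - m)
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        block = tl.load(
            row_ptr + cols, mask=cols < n_cols, other=-float("inf")
        ).to(tl.float32)
        new_m = tl.maximum(m, tl.max(block, axis=0))
        # Rescale the running sum to the new max before adding this block.
        s = s * tl.exp(m - new_m) + tl.sum(tl.exp(block - new_m), axis=0)
        m = new_m

    # loss = log(sum_j exp(x_j)) - x_label = log(s) + m - x_label
    tl.store(losses_ptr + row, tl.log(s) + m - target_logit)
```

Launched with one program per row, e.g. `looped_cross_entropy_fwd[(n_rows,)](logits, labels, losses, vocab_size, logits.stride(0), BLOCK_SIZE=4096)`. Note that when the vocab fits in a single block the loop runs exactly once, which is presumably where the single-pass fast path would have to be preserved.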
🔄 Alternatives Considered
We have other implementations, but they are much slower.
📈 Potential Benefits
Faster training and lower memory usage with large vocab sizes.
📝 Additional Context