AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!

Flash attention - head_dim 64 #1047

[Open] peregilk opened this issue 4 days ago

peregilk commented 4 days ago

I have tried using MaxText to train Llama 3.2 3B. This seems to work fine with just minor modifications to the configs.

However, I am unable to train Llama 3.2 1B. The reason is that Flash/Splash attention seems to require head_dim to be divisible by 128, while the 1B model has a head_dim of only 64 (a 2048-dim hidden state split across 32 query heads, versus 3072 across 24 heads, i.e. 128, for the 3B). I get a "not implemented" error, and falling back to dot_product attention at long context lengths is really challenging.
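One workaround I have been considering (an untested sketch, not MaxText's actual splash-attention wrapper: `attention_with_padded_heads` and `pad_head_dim` are made-up helpers, and `jax.nn.dot_product_attention` is only a stand-in for the TPU kernel) is to zero-pad head_dim from 64 to 128 before the kernel call and slice the output back afterwards. Zero-padded q/k columns add nothing to the attention logits and zero-padded v columns only produce zero output columns, so the sliced result should match the unpadded attention:

```python
import jax
import jax.numpy as jnp


def pad_head_dim(x, target=128):
  """Zero-pad the last (head_dim) axis of a [batch, seq, heads, head_dim] array to `target`."""
  return jnp.pad(x, ((0, 0),) * (x.ndim - 1) + ((0, target - x.shape[-1]),))


def attention_with_padded_heads(q, k, v, attention_fn, target=128):
  """Run `attention_fn` with head_dim zero-padded to `target`, then slice the output back."""
  head_dim = q.shape[-1]
  if head_dim % target == 0:
    return attention_fn(q, k, v)
  # If the kernel scales logits by 1/sqrt(padded head_dim) internally, rescale q
  # so the effective scale is still 1/sqrt(original head_dim).
  q = pad_head_dim(q, target) * jnp.sqrt(target / head_dim).astype(q.dtype)
  k = pad_head_dim(k, target)
  v = pad_head_dim(v, target)
  return attention_fn(q, k, v)[..., :head_dim]


# Toy check, with jax.nn.dot_product_attention standing in for the splash kernel.
q = k = v = jnp.ones((1, 1024, 32, 64), dtype=jnp.bfloat16)
out = attention_with_padded_heads(q, k, v, jax.nn.dot_product_attention)
print(out.shape)  # (1, 1024, 32, 64)
```

The obvious trade-off is roughly double the attention FLOPs plus the padded q/k/v copies, which may still be preferable to dot_product attention at long context lengths.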

Any ideas?