I have tried using MaxText to train Llama 3.2 3B, and that works fine with only minor modifications to the configs.
However, I am unable to train the Llama 3.2 1B model. The reason is that Flash/Splash attention seems to require a head_dim divisible by 128, while the 1B model's head_dim is only 64, so I get a "not implemented" error. Falling back to dot_product attention is not really workable at long context lengths.
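One workaround I have been considering is zero-padding Q/K/V from head_dim 64 to 128 before the kernel call and slicing the output back. Below is a rough sketch of the idea, not actual MaxText code: `attn_fn` is a stand-in for whatever Splash/Flash entry point gets used, the `[batch, heads, seq, head_dim]` layout is assumed, and whether the kernel's internal 1/sqrt(head_dim) scaling behaves as assumed here would need to be verified.

```python
import jax.numpy as jnp


def pad_head_dim(x, target=128):
    """Zero-pad the last (head_dim) axis of a [batch, heads, seq, head_dim] array."""
    pad = target - x.shape[-1]
    return jnp.pad(x, ((0, 0), (0, 0), (0, 0), (0, pad)))


def attention_with_padded_heads(q, k, v, attn_fn, target=128):
    """Run a kernel that requires head_dim % 128 == 0 on head_dim=64 inputs.

    Zero-padding Q and K leaves the QK^T logits unchanged, and the zero tail
    of V only contributes zero columns to the output, which we slice off.
    If the kernel scales logits by 1/sqrt(head_dim) internally, it would use
    the padded dim (128), so Q is pre-scaled by sqrt(128/64) to compensate
    (assumption about the kernel's scaling behavior).
    """
    orig_dim = q.shape[-1]
    q = q * jnp.sqrt(target / orig_dim)
    q, k, v = pad_head_dim(q, target), pad_head_dim(k, target), pad_head_dim(v, target)
    out = attn_fn(q, k, v)  # hypothetical kernel entry point
    return out[..., :orig_dim]
```

This obviously wastes compute and memory on the padded half of the head dimension, so I'd prefer a proper fix if one exists.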
Any ideas?