Closed Taekyoon closed 11 months ago
The context size 2048 is already really small, so memory efficient attention does not make a difference here. Memory efficient attention will be effective when the context length is much larger than the hidden dimension, such as 16k or 32k context. Besides that, in your v3-8 setting, you only have 128GB of memory in total, and the parameters and optimizer states already consume 84GB of them, so I don't see how you can easily train a 7b model with v3-8.
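For context, the 84GB figure works out from a simple byte count (assuming fp32 parameters and Adam, which keeps two extra fp32 buffers per parameter; this is a back-of-envelope sketch, not EasyLM's exact accounting):

```python
# Back-of-envelope check of the 84 GB figure: fp32 parameters (4 bytes each)
# plus Adam's two fp32 moment buffers (8 bytes) = 12 bytes per parameter.
def train_state_gb(num_params, bytes_per_param=4 + 8):
    return num_params * bytes_per_param / 1e9

print(train_state_gb(7e9))  # 84.0 GB -- well over half of a v3-8's 128 GB
```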
Oh, I'm trying a 1b model with v3-8 now for testing. So do you mean that if the context size is larger than, say, 4096, memory-efficient attention will work?
So basically the memory will not grow quadratically with context size, but it will still grow linearly. It is not like if you train a 4k context model, using memory efficient attention will consume less memory than a 2k context model with normal attention.
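The quadratic-vs-linear point can be sketched with some rough arithmetic (head count, hidden size, and bytes-per-element here are made-up illustrative numbers):

```python
# Vanilla attention materializes a [batch, heads, seq, seq] score matrix,
# so its activation memory grows quadratically with context length.
def vanilla_scores_gb(seq_len, num_heads=16, batch=1, bytes_per_el=2):
    return batch * num_heads * seq_len**2 * bytes_per_el / 1e9

# Memory-efficient attention keeps only O(seq * hidden) activations,
# which still grows linearly with context length.
def efficient_acts_gb(seq_len, hidden=4096, batch=1, bytes_per_el=2):
    return batch * seq_len * hidden * bytes_per_el / 1e9

for n in (2048, 4096, 16384):
    print(n, vanilla_scores_gb(n), efficient_acts_gb(n))
```

Doubling the context from 2k to 4k quadruples the vanilla score-matrix memory but only doubles the memory-efficient activations, which is why the savings only become dramatic at long contexts like 16k or 32k.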
Got it. So you mean that if I train a 4k context model with flash attention, it will consume less memory than a 4k model without flash attention, but still not less than a 2k model? Is that correct?
I believe that should be the case.
Got it :) let me try with 4096 first. Would it be useful if I share this experiment, or have you already tested that?

One more thing :) What if I set every scan chunk to half of the current size? Does that make sense as a way to reduce the memory cost?
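Halving the scan chunk size should indeed lower the peak memory of the scanned attention, at the cost of more sequential steps. A toy NumPy sketch of the idea (illustrative only, not EasyLM's actual implementation):

```python
import numpy as np

def chunked_attention(q, k, v, chunk_size):
    """Process queries in chunks so only a [chunk, seq] score block is
    live at a time; a smaller chunk_size means a lower peak memory."""
    outs = []
    for i in range(0, q.shape[0], chunk_size):
        qc = q[i:i + chunk_size]                          # [chunk, d]
        scores = qc @ k.T / np.sqrt(q.shape[1])           # [chunk, seq]
        weights = np.exp(scores - scores.max(-1, keepdims=True))
        weights /= weights.sum(-1, keepdims=True)
        outs.append(weights @ v)                          # [chunk, d]
    return np.concatenate(outs)

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(64, 8)) for _ in range(3))
# The result is the same for any chunk size; only peak memory differs.
assert np.allclose(chunked_attention(q, k, v, 16), chunked_attention(q, k, v, 64))
```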
@young-geng Found something! :)
I've tried to train the model with bf16, and this code line causes extra memory consumption. When it calls this function, the dtype is ignored, which increases memory consumption due to the type conversion from bf16 into fp32. I found the function runs in bf16 once I added the dtype parameter.
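This kind of silent upcast is easy to reproduce in isolation. A minimal JAX sketch of the failure mode (an illustration only, not the actual EasyLM code line):

```python
import jax.numpy as jnp

x = jnp.ones((4,), dtype=jnp.bfloat16)

# An intermediate created without the input's dtype defaults to float32...
acc = jnp.zeros((4,))
print((x + acc).dtype)            # float32: the bf16 input gets promoted

# ...while threading dtype through keeps the whole computation in bf16.
acc_bf16 = jnp.zeros((4,), dtype=x.dtype)
print((x + acc_bf16).dtype)       # bfloat16
```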
Good catch! Just committed a fix.
@young-geng I think this issue needs to be reopened, because this memory-efficient attention doesn't work well in either fp32 or bf16. I tried applying it to a 1b llama model (context length 4096) on a v3-8. However, the result was slower than the run without flash attention. In particular, scan mlp slows training down by more than 10x. And even when I apply flash attention alone (scan attention only), it consumes more memory than the default training setting. So I conclude that this module should be validated within the EasyLM project. (I know the module was already tested by Googlers.)
I'm not sure that this flash attention performs better than the current model training process on a v3-8. Maybe it could work on a larger device like a v4-512, but I only have limited resources. Did you actually validate this flash attention, or was the module just added into llama_model? Please let me know if I can help you.
I haven't fully tested and optimized it, and that's exactly why the default configuration for using them is set to False. Regarding the slow down, it is expected that when normal attention or MLP fits into memory, using the scanned version would result in slower performance because it prevents XLA from doing some optimizations. Theoretically, it is possible to achieve even better performance than the vanilla one, like how the flash attention CUDA kernel runs faster than the vanilla attention, but this requires low level control over the TPU, which is unfortunately not supported at this time.
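For reference, the "scanned" versions amount to something like this `jax.lax.scan` sketch (illustrative only; the chunk count and the MLP body are made-up assumptions, not EasyLM's actual module):

```python
import jax
import jax.numpy as jnp

def mlp(x, w):
    return jax.nn.relu(x @ w)

def scanned_mlp(x, w, num_chunks=4):
    # Split the leading axis into chunks and process them sequentially
    # with lax.scan: peak activation memory drops, but XLA loses the
    # chance to fuse and parallelize across the whole axis, which is
    # why the scanned version can be slower when everything fits anyway.
    xs = x.reshape(num_chunks, -1, x.shape[-1])
    _, ys = jax.lax.scan(lambda carry, chunk: (carry, mlp(chunk, w)), None, xs)
    return ys.reshape(-1, ys.shape[-1])

x = jnp.ones((8, 16))
w = jnp.ones((16, 32))
# Numerically identical to the unscanned MLP; only the schedule differs.
assert jnp.allclose(scanned_mlp(x, w), mlp(x, w))
```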
Got it! I'm still planning to test this module with multi-pods and find the bottlenecks. I'll keep this issue open to share my updates. :)
I think the blockwise attention PR by @lhao499 is working well. I can see reduced memory consumption when I extend the sequence length from 2048 to 4096.
I'll close this issue. Thanks to liu :)
I tried to apply both scan_attn and scan_mlp, and this configuration causes more memory consumption. I used the same configuration as the 1b llama model on a v3-8 device, and my settings are like this.
Do you have any ideas of this issue?