young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.

Flash attention does not make training memory efficient #81

Closed Taekyoon closed 11 months ago

Taekyoon commented 1 year ago

I tried to apply both scan_attn and scan_mlp, but these configurations cause more memory consumption. I used the same configuration as the 1B LLaMA model on a v3-8 device, and my settings are as follows.

python -m EasyLM.models.llama.llama_train \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --total_steps=10000 \
    --log_freq=50 \
    --load_llama_config='1b-scan-both' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --tokenizer.name='beomi/kollama-7b' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=1e-3 \
    --optimizer.adamw_optimizer.end_lr=1e-4 \
    --optimizer.adamw_optimizer.lr_warmup_steps=1000 \
    --optimizer.adamw_optimizer.lr_decay_steps=10000 \
    --optimizer.accumulate_gradient_steps=16 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=64 \
    --train_dataset.json_dataset.tokenizer_processes=1 \
    --checkpointer.save_optimizer_state=True \
    --checkpointer.float_dtype='bf16'

Do you have any ideas about this issue?

young-geng commented 1 year ago

The context size of 2048 is already quite small, so memory efficient attention does not make a difference here. Memory efficient attention becomes effective when the context length is much larger than the hidden dimension, such as a 16k or 32k context. Besides that, in your v3-8 setting you only have 128GB of memory in total, and the parameters and optimizer states of a 7B model already consume 84GB of it, so I don't see how you could easily train a 7B model on a v3-8.
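For reference, a rough back-of-the-envelope check of that 84GB figure (my own sketch, assuming fp32 parameters with fp32 AdamW moments, which may not match exactly how EasyLM stores or shards them):

    # Rough arithmetic, not EasyLM-specific: 7B parameters trained with AdamW in fp32.
    params = 7e9
    bytes_per_param = 4           # fp32 weights
    bytes_per_adam_state = 4 + 4  # AdamW first and second moments, fp32 each
    total_gb = params * (bytes_per_param + bytes_per_adam_state) / 1e9
    print(total_gb)               # ~84 GB, before counting gradients or activations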

Taekyoon commented 1 year ago

Oh, I tried the 1B model on v3-8 just for testing. So do you mean that if the context size is larger, like 4096, memory efficient attention will work?

young-geng commented 1 year ago

So basically the memory will not grow quadratically with context size, but it will still grow linearly. It is not the case that a 4k context model using memory efficient attention will consume less memory than a 2k context model with normal attention.
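A rough sketch of that scaling argument (my own numbers, assuming the attention score block is the dominant activation and one block per head is live at a time):

    # Standard attention materializes a [T, T] score matrix per head;
    # the scanned version only holds a [chunk, T] block at a time.
    def score_memory_mb(seq_len, chunk=None, heads=16, bytes_per_el=2):
        rows = seq_len if chunk is None else chunk
        return heads * rows * seq_len * bytes_per_el / 1e6

    print(score_memory_mb(2048))             # full:    ~134 MB per layer
    print(score_memory_mb(4096))             # full:    ~537 MB per layer (4x, quadratic)
    print(score_memory_mb(4096, chunk=256))  # scanned: ~34 MB per layer (linear in T)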

Taekyoon commented 1 year ago

I got it. So you mean that if I train a 4k context model with flash attention, it will consume less memory than training without flash attention, but not less than the 2k context model? Is that correct?

young-geng commented 1 year ago

I believe that should be the case.

Taekyoon commented 1 year ago

Got it :) Let me try with 4096 first. Would it be good to share this experiment, or did you already test that?

Taekyoon commented 1 year ago

One more thing :) What if I set every scan chunk size to half of the current setting? Does it make sense that this would reduce the memory cost?
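For what it's worth, my rough expectation (an assumption about how the chunked scan behaves, not something verified against the EasyLM code): halving the chunk size should roughly halve the live score block, at the cost of more scan steps.

    # Assuming one [query_chunk, seq_len] score block is live per scan step.
    seq_len, heads, bytes_per_el = 4096, 16, 2
    for chunk in (512, 256, 128):
        live_mb = heads * chunk * seq_len * bytes_per_el / 1e6
        steps = seq_len // chunk
        print(chunk, round(live_mb), "MB live,", steps, "scan steps")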

Taekyoon commented 1 year ago

@young-geng Found something! :)

I've tried to train the model with bf16, and this code line causes more memory consumption.

When it calls this function, the dtype is ignored, which increases memory consumption through a type conversion from bf16 to fp32. I found that the function works in bf16 once I added the dtype parameter.
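To illustrate the kind of issue described (a generic Flax sketch, not the actual EasyLM code path or the specific function fixed here): when a module's computation dtype is left unset, fp32 parameters promote bf16 inputs to fp32, roughly doubling the activation memory.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.ones((4, 512), dtype=jnp.bfloat16)
    rng = jax.random.PRNGKey(0)

    # dtype left unset: the fp32 params promote the bf16 input, output is float32.
    y_default, _ = nn.Dense(features=512).init_with_output(rng, x)

    # dtype passed explicitly: the computation and output stay in bfloat16.
    y_bf16, _ = nn.Dense(features=512, dtype=jnp.bfloat16).init_with_output(rng, x)

    print(y_default.dtype, y_bf16.dtype)  # float32 bfloat16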

young-geng commented 1 year ago

Good catch! Just committed a fix.

Taekyoon commented 1 year ago

@young-geng I think this issue needs to be reopened, because the memory efficient attention doesn't work well in either fp32 or bf16. I tried to apply it to a 1B LLaMA model (context length 4096) on v3-8. However, the result was slower than training without flash attention. In particular, scan mlp slows training down by more than 10x. Also, when I apply flash attention only (scan attention without scan mlp), it still consumes more memory than the default training setting. So I conclude that this module should be validated within the EasyLM project. (I know the module itself was already tested by Googlers.)

I'm not sure this flash attention performs better than the current training process on v3-8. Maybe it would work on a larger device like v4-512, but I only have limited resources. Did you actually validate this flash attention, or was it just added as a module to llama_model? Please let me know if I can help.

young-geng commented 1 year ago

I haven't fully tested and optimized it, and that's exactly why the default configuration for using them is set to False. Regarding the slow down, it is expected that when normal attention or MLP fits into memory, using the scanned version would result in slower performance because it prevents XLA from doing some optimizations. Theoretically, it is possible to achieve even better performance than the vanilla one, like how the flash attention CUDA kernel runs faster than the vanilla attention, but this requires low level control over the TPU, which is unfortunately not supported at this time.
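For readers following along, here is a minimal sketch of what "scanned" attention does conceptually. This is my own simplified illustration (single head, no causal mask, no streaming softmax over key chunks), not the EasyLM implementation:

    import jax
    import jax.numpy as jnp

    def scanned_attention(q, k, v, chunk_size=128):
        # q, k, v: [seq_len, head_dim]; seq_len must be divisible by chunk_size.
        seq_len, head_dim = q.shape
        q_chunks = q.reshape(seq_len // chunk_size, chunk_size, head_dim)

        def step(carry, q_chunk):
            # Only a [chunk_size, seq_len] score block is live, never [seq_len, seq_len].
            scores = q_chunk @ k.T / jnp.sqrt(head_dim)
            out = jax.nn.softmax(scores, axis=-1) @ v
            return carry, out

        _, out_chunks = jax.lax.scan(step, None, q_chunks)
        return out_chunks.reshape(seq_len, head_dim)

    # The scan keeps peak memory linear in seq_len, but the sequential loop
    # also limits how aggressively XLA can fuse the computation, which is
    # consistent with the slowdown described above.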

Taekyoon commented 1 year ago

Got it! I'm still thinking of testing this module with multiple pods and finding some bottlenecks. I'll keep this issue open to share my updates. :)

Taekyoon commented 11 months ago

I think the blockwise attention from the PR by @lhao499 is working well. I can see reduced memory consumption when I extend the sequence length from 2048 to 4096.

I'll close this issue. Thanks to Liu :)