artidoro / qlora

QLoRA: Efficient Finetuning of Quantized LLMs
https://arxiv.org/abs/2305.14314
MIT License

FlashAttention support? #221

Open BugReporterZ opened 1 year ago

BugReporterZ commented 1 year ago

This might be more of a general question, but is it possible to use FlashAttention with QLoRA in order to further decrease memory requirements when finetuning?

I would guess that in principle it could be done, but has anybody actually attempted implementing it?

artidoro commented 1 year ago

Hey! Flash attention is orthogonal to QLoRA, meaning that you can combine the two. In fact, we had implemented it for LLaMA at some point but didn't end up keeping it. More generally, if your base model uses flash attention you can use it with QLoRA. I would look for models that implement flash attention or implement it for your favorite base model and then finetune with QLoRA.

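For reference, recent versions of transformers, peft, and bitsandbytes let you request FlashAttention directly when loading the base model, so no monkey patching is needed. A minimal sketch along those lines; the model name, version notes, and LoRA hyperparameters are placeholder assumptions, not from this thread:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load the base model in 4-bit (NF4) with bf16 compute, and ask transformers
# for its FlashAttention-2 implementation (requires a recent transformers
# release and the flash-attn package installed).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # FlashAttention needs fp16/bf16 activations
    ),
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

# Standard QLoRA setup: prepare the quantized model and attach LoRA adapters.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, LoraConfig(
    r=64, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
))
# `model` can then be passed to the usual training loop / Trainer.
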
BugReporterZ commented 1 year ago

Thanks for replying! Great to learn that there are no inherent issues preventing FlashAttention from being combined with QLoRA.

With the latest FlashAttention-2 promising even further performance improvements, and given that the memory pressure of standard attention grows rapidly with context size, having built-in support even just for Llama (and now Llama 2, which has a native 4k context size) would be great. While using QLoRA for finetuning is pretty straightforward, adapting it to take advantage of FlashAttention is not so obvious given its low-level nature: there is a so-called monkey patch available for LLaMA, but I have not personally been successful in applying it to the QLoRA code.

So, if there is already existing code demonstrating how to use it in practice with QLoRA, it would certainly benefit many people if it could be uploaded to the repository.

artidoro commented 1 year ago

That's a good point! I agree we should look into this. If someone wants to contribute an example in the meantime we would appreciate the help.

BugReporterZ commented 1 year ago

Perhaps some of the code from Axolotl could be used. It's a trainer that supports QLoRA together with different attention mechanisms, including FlashAttention.

I haven't been able to make FlashAttention work with it yet, but with xformers attention (another supported method) I could train Llama-13B on 4096-token sequences in less than 16 GB of VRAM (at a batch size of 1), which is almost unbelievable. On the other hand, training speed did not appear to increase.

It's likely that FlashAttention would yield similar or better benefits and make 30B-class LLMs trainable on a 24 GB GPU with long sequences.

ehartford commented 1 year ago

We would also be very happy to see Flash Attention 2 support added to this tool.

LagPixelLOL commented 1 year ago

:octocat:

ehartford commented 1 year ago

I gave this a shot and tried to implement flash attention in the same way that FastChat and Axolotl do.

https://github.com/artidoro/qlora/pull/235

It doesn't seem to work. I was wondering whether anyone more familiar with CUDA could figure out what is going wrong?

@artidoro

The error message is:

  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 340, in forward
    raise ValueError(
ValueError: Attention mask should be of size (1, 1, 513, 513), but is torch.Size([1, 513])

pankajarm commented 1 year ago

Same boat here. I tried testing both versions of flash attention individually using the monkey-patching code from FastChat (https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py), but got stuck on a similar error to the one @ehartford reported. I hope someone will figure out what is causing it.

ehartford commented 1 year ago

Jon Durbin has a fork here that implements flash attention. I'll see if I can adapt the code into a PR.

https://github.com/jondurbin/qlora

pankajarm commented 1 year ago

Yup, I tried that already; it throws the same error for Llama 2 70B. It's the same monkey patch code from FastChat that I tried to integrate.

ehartford commented 1 year ago

Here is the full stack trace:

Traceback (most recent call last):
  File "/home/eric/git/qlora/qlora.py", line 845, in <module>
    train()
  File "/home/eric/git/qlora/qlora.py", line 807, in train
    train_result = trainer.train()
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/utils/operations.py", line 581, in forward
    return model_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/utils/operations.py", line 569, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/peft/peft_model.py", line 922, in forward
    return self.base_model(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
    outputs = self.model(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 685, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 681, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/eric/miniconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 340, in forward
    raise ValueError(
ValueError: Attention mask should be of size (1, 1, 513, 513), but is torch.Size([1, 513])

ehartford commented 1 year ago

I got around that problem by moving the patch code to an earlier point before loading the model.

But I hit another error:

RuntimeError: FlashAttention only support fp16 and bf16 data type

So it seems Flash Attention only supports fp16 and bf16, not 4-bit. We will need to wait for 4-bit support.

https://github.com/Dao-AILab/flash-attention/issues/398

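For reference, the ordering fix described above amounts to calling the monkey patch before the model is instantiated. A minimal sketch, assuming the FastChat patch module linked earlier is importable (the exact function name may differ between patch variants, and the model name is a placeholder):

# Apply the LLaMA flash-attention monkey patch first, so the patched forward
# and its matching attention-mask handling are in place before the model,
# PEFT wrappers, and accelerate hooks are created.
from fastchat.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()

# Only afterwards load the (4-bit) base model and continue with the usual
# QLoRA setup; patching after loading reproduces the attention-mask shape
# error reported above.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
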
b0xtch commented 1 year ago

You can also add this part:

# Only apply the patch on GPUs with compute capability >= 8 (Ampere or newer),
# where FlashAttention's bf16 kernels are supported.
if torch.cuda.get_device_capability()[0] >= 8:
    from utils.llama_patch import replace_attn_with_flash_attn
    replace_attn_with_flash_attn()

Ltrack commented 1 year ago

I had both of these issues. I got training working by running

replace_attn_with_flash_attn()

before loading the model, and

from utils.llama_patch import upcast_layer_for_flash_attention
model = upcast_layer_for_flash_attention(model, torch.bfloat16)

before the training loop. Check the link below to get the llama_patch file used. Credits to philschmid.

https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/utils/llama_patch.py
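
Putting the pieces from this thread together, a sketch of the full ordering (assuming philschmid's utils/llama_patch.py linked above is on the Python path; the model name and LoRA settings are placeholders):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from utils.llama_patch import replace_attn_with_flash_attn, upcast_layer_for_flash_attention

# 1) Patch LLaMA attention before the model is loaded.
replace_attn_with_flash_attn()

# 2) Load the 4-bit base model as usual for QLoRA and attach LoRA adapters.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, LoraConfig(
    r=64, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
))

# 3) Upcast the non-quantized layers to bf16 before training, so the tensors
#    reaching flash attention are not fp32 and the
#    "FlashAttention only support fp16 and bf16 data type" error is avoided.
model = upcast_layer_for_flash_attention(model, torch.bfloat16)

# 4) Hand `model` to the Trainer / training loop as usual.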