Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

FlashAttention works on a single GPU, but crashes with accelerate DP on multiple GPUs (FlashAttention only support fp16 and bf16 data type) #822

Open Andcircle opened 9 months ago

Andcircle commented 9 months ago

### System Info

- `Accelerate` version: 0.22.0
- Platform: Linux-5.10.192-183.736.amzn2.x86_64-x86_64-with-glibc2.29
- Python version: 3.8.10
- Numpy version: 1.23.1
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 1121.81 GB
- GPU type: NVIDIA A100-SXM4-80GB
- transformers version: 4.37.2
- trl version: 0.7.11.dev0
- flash-attn version: 2.5.2

    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
    RuntimeError: FlashAttention only support fp16 and bf16 data type

### Reproduction

The following script works as expected on a single GPU, but when running on multiple GPUs with DP it fails with:

    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
    RuntimeError: FlashAttention only support fp16 and bf16 data type

```python
import os
import wandb

import torch
from accelerate import Accelerator
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)

from trl import DataCollatorForCompletionOnlyLM

from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

import sys
project_root = '/'.join(os.path.dirname(__file__).split('/')[:-1])
print(project_root)
sys.path.append(project_root)
from utils.meta_loader import write_meta, read_meta

import transformers

# bench
alpha = 16
rank = 64
batch_size = 2
length = 4096
accumulate_steps = 1
lr = 5e-5

train_dataset = load_from_disk("/mnt/localssd/dataset/llava_processed_dataset/train")
eval_dataset = load_from_disk("/mnt/localssd/dataset/llava_processed_dataset/test")    

run_name = "llava_debug"
save_dir = "/mnt/localssd/llava_debug"

compute_dtype = getattr(torch, "float16")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # load_in_8bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    # llm_int8_skip_modules=["multi_modal_projector"]
)

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    # "llava-hf/bakLlava-v1-hf",
    quantization_config=bnb_config,
    trust_remote_code=True, 
    device_map={'':torch.cuda.current_device()},
    torch_dtype=torch.float16,
    use_flash_attention_2=True
    )

target_modules = [
    "*language_model.*q_proj", 
    "*language_model.*k_proj", 
    "*language_model.*v_proj", 
    "*language_model.*o_proj", 
    "*language_model.*gate_proj", 
    "*language_model.*up_proj", 
    "*language_model.*down_proj", 
    "*language_model.*lm_head"]

modules_to_save = ["linear_1", "linear_2"]

peft_config = LoraConfig(
    lora_alpha=alpha,
    lora_dropout=0.1,
    r=rank,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    modules_to_save=modules_to_save
)

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, peft_config)

training_arguments = TrainingArguments(
    output_dir=save_dir,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=accumulate_steps,
    optim="paged_adamw_32bit",
    save_steps=500,
    logging_steps=10,
    learning_rate=lr,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=100,
    warmup_ratio=0.03,
    # group_by_length=True,
    lr_scheduler_type="constant",
    run_name=run_name,
    evaluation_strategy="steps",
    eval_steps=200,
    ddp_find_unused_parameters=False,
    gradient_checkpointing=True,
    # weight_decay=0.01,
    # dataloader_num_workers=NUM_PROC//2
)

model.config.use_cache = False  # the KV cache is not used during fine-tuning

def test_data_collator(datas):
    result = {}
    input_ids = [torch.Tensor(d['input_ids']) for d in datas]
    attention_mask = [torch.Tensor(d['attention_mask']) for d in datas]
    pixel_values = [torch.Tensor(d['pixel_values']) for d in datas]
    labels = [torch.Tensor(d['labels']) for d in datas]

    result['input_ids'] = torch.concat(input_ids).type(torch.int64)
    result['attention_mask'] = torch.concat(attention_mask).type(torch.int64)
    result['pixel_values'] = torch.concat(pixel_values)
    result['labels'] = torch.concat(labels).type(torch.int64)
    return result

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_arguments,
    data_collator=test_data_collator
)

trainer.train()
```

### Expected behavior

The behavior should be the same on a single GPU and on multiple GPUs.
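
As a debugging aid (not part of the original report), one way to narrow this down is to print the dtype that actually reaches the attention modules on each rank. This is a sketch: the `self_attn` name match and the `log_attention_dtypes` helper are assumptions and may need adjusting for a given transformers version.

```python
import torch

def log_attention_dtypes(model):
    """Print the dtype of the hidden states entering each attention module.

    Debugging sketch only: matching module names on "self_attn" is an
    assumption and may differ across transformers versions.
    """
    def make_hook(name):
        def hook(module, args, kwargs):
            hidden = args[0] if args else kwargs.get("hidden_states")
            if isinstance(hidden, torch.Tensor):
                print(f"[cuda:{torch.cuda.current_device()}] {name}: {hidden.dtype}")
        return hook

    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            module.register_forward_pre_hook(make_hook(name), with_kwargs=True)

# e.g. call log_attention_dtypes(model) right before trainer.train()
```
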
tridao commented 9 months ago

I'm not familiar with accelerate or how transformers uses FlashAttention, you'd probably get better help asking on those repos.

ArthurZucker commented 9 months ago

I am getting a similar issue, without training, with torch nightly on Llama, so I can confirm something's wrong! It might be on our side, but as far as I tested all the inputs' dtypes were bfloat16 and I still got the issue. The reproducer is here with attn_implementation="flash_attention_2" and the corresponding PR on transformers.

- `transformers` version: 4.38.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Python version: 3.10.0
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.27.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.3.0.dev20240208+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
ArthurZucker commented 9 months ago

    >>> from flash_attn import flash_attn_func
    >>> import torch
    >>> print(torch.__version__)
    2.3.0.dev20240208+cu121
    >>> flash_attn_func(torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), 1, softmax_scale=1, causal=True)

    ....

    File ~/miniconda3/envs/py310/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
         49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
         50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    ---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         52     q,
         53     k,
         54     v,
         55     None,
         56     alibi_slopes,
         57     dropout_p,
         58     softmax_scale,
         59     causal,
         60     window_size[0],
         61     window_size[1],
         62     return_softmax,
         63     None,
         64 )
         65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

    RuntimeError: FlashAttention only support fp16 and bf16 data type

this doesn't work for me again, might be because I have. cc @tridao not sure how relevant this is

tridao commented 9 months ago

> this doesn't work for me again, might be because I have. cc @tridao not sure how relevant this is

The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).
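
For reference, a minimal call that satisfies both requirements might look like this (a sketch with illustrative shapes, assuming flash-attn 2.x on an Ampere-or-newer GPU):

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64

# q, k, v must be fp16/bf16, live on a CUDA device, and have shape
# (batch, seqlen, nheads, headdim).
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # (batch, seqlen, nheads, headdim)
```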

ArthurZucker commented 9 months ago

The error happens before that check, but it does seem to be torch nightly: the transformers snippet works with torch 2.2, whereas with nightly I get the `FlashAttention only support fp16 and bf16 data type` error, so that is more reliable. (With my snippet on torch 2.2 I instead get `RuntimeError: q must be on CUDA`, so it is a different error.)

tridao commented 9 months ago

> I am getting a similar issue, without training, with torch nightly on Llama, so I can confirm something's wrong! It might be on our side, but as far as I tested all the inputs' dtypes were bfloat16 and I still got the issue. The reproducer is here with attn_implementation="flash_attention_2" and the corresponding PR on transformers.
>
> - flash_attn=2.5.3 + torch nightly so (2.3 ish)

I can't run the reproducer right now because StaticCache is not in transformers 4.37.2 (the latest stable version).

yiakwy-xpu-ml-framework-team commented 8 months ago

> this doesn't work for me again, might be because I have. cc @tridao not sure how relevant this is
>
> The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).

Yes, FlashAttention uses (batch, seqlen, nheads, headdim) to represent its inputs; however, in a lot of software (Triton kernels, for example) there are good reasons to use (batch, nheads, seqlen, headdim) for an easier layout arrangement.

Actually they are equivalent with this mapping:

    def permute(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seqlen, nheads * headdim) -> (batch, nheads, seqlen, headdim)
        new_x_shape = x.size()[:-1] + (self.nheads, self.headdim)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

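As a standalone illustration of that equivalence (shapes are made up for the example, not taken from the thread), a single transpose converts between the two layouts:

```python
import torch

batch, nheads, seqlen, headdim = 2, 8, 128, 64

# Tensor in the (batch, nheads, seqlen, headdim) layout that Triton-style kernels often use.
x = torch.randn(batch, nheads, seqlen, headdim, dtype=torch.bfloat16, device="cuda")

# FlashAttention expects (batch, seqlen, nheads, headdim); transpose and make contiguous.
x_fa = x.transpose(1, 2).contiguous()

# ...call flash_attn_func on q/k/v tensors in this layout...

# Transpose back afterwards to recover the original layout.
x_back = x_fa.transpose(1, 2)
assert x_back.shape == (batch, nheads, seqlen, headdim)
```
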
But it is weird that the error (I have tested with the latest version) says "FlashAttention only support fp16 and bf16 data type".

    // mha_fwd https://github.com/Dao-AILab/flash-attention/blob/6bbc532388e61185a92e2a563126739967b4c8c5/csrc/flash_attn/flash_api.cpp#L339-L339
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }

I have checked the repo; the C++ templates would need to be updated to support more dtypes (I have experience with near-memory chip op libraries). Currently I have to add these otherwise unnecessary casts to help teams use FlashAttention v2:

    if q.dtype == torch.float32:
        # FlashAttention only accepts fp16/bf16, so cast fp32 inputs down first.
        q = q.to(torch.float16, non_blocking=True)
        k = k.to(torch.float16, non_blocking=True)
        v = v.to(torch.float16, non_blocking=True)
    elif q.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz):
        capability = torch.cuda.get_device_capability()
        if capability[0] <= 8:
            raise RuntimeError("Flash attention for FP8 (needs Hopper TE support) is currently only supported for compute capability >= 9.0")
        else:
            # TODO (yiakwy) : add FP8 support
            raise NotImplementedError

    output = flash_attn_func(q, k, v, dropout_p=self.dropout.p, causal=is_causal)
    output = revert_mold_flash_attn_input(output)

    if output_attentions:
        raise Exception("Does not support output attention weights inside flash attention.")

    if output.dtype != torch.float32:
        # TODO (yiakwy) : add support of fp16 and bf16
        # if the output dtype is not fp32 (by default flash attention produces fp16/bf16 output), cast it back
        output = output.to(torch.float32, non_blocking=True)

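A slightly more general version of the same workaround could preserve the caller's dtype instead of hard-coding fp32. This is a sketch, not flash-attn API; the `flash_attn_autocast` name is made up for illustration:

```python
import torch
from flash_attn import flash_attn_func

def flash_attn_autocast(q, k, v, **kwargs):
    """Cast unsupported input dtypes to bf16/fp16 for FlashAttention, then cast back.

    Sketch only: the kernel itself accepts just fp16/bf16, so any other dtype
    has to be converted (with the usual precision loss) by the caller.
    """
    orig_dtype = q.dtype
    if orig_dtype not in (torch.float16, torch.bfloat16):
        target = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        q, k, v = (t.to(target) for t in (q, k, v))
    out = flash_attn_func(q, k, v, **kwargs)
    return out.to(orig_dtype)
```
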
So the error message should be updated, right?

thepowerfuldeez commented 8 months ago

I confirm that flash-attn==2.5.6 doesn't work with torch==2.3.0a0+40ec155e58.nv24.3 nightly even though inputs are indeed in torch.bfloat16 format! I rolled back to torch2.2 stable and reinstalled flash-attn and now it works.