thu-ml / SageAttention

Quantized attention that achieves speedups of 2.1-3.1x and 2.7-5.1x over FlashAttention2 and xformers, respectively, without losing end-to-end metrics across various models.
Apache License 2.0
608 stars, 28 forks

LLM acc problem #55

Open laomao0 opened 5 days ago

laomao0 commented 5 days ago

Hi, I set F.scaled_dot_product_attention = sageattn in modeling_llama.py and ran the inference code. I can see that it dispatches to sageattn_qk_int8_pv_fp16_cuda in sageattention/core.py.

The results are:

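| Llama-2-7B | wiki PPL | c4 PPL | PIQA | ARC-e | ARC-c | HellaSwag | WinoGrande | AVG |
|---|---|---|---|---|---|---|---|---|
| FP16 | 5.12 | 6.63 | 78.07 | 76.26 | 43.52 | 57.16 | 69.22 | 64.846 |
| sageattn_qk_int8_pv_fp16_cuda | 5.1178 | 6.7139 | 55.93 | 28.37 | 25.09 | 49.24 | 49.57 | 41.64 |

wiki and c4 are perplexities (lower is better); the other columns are accuracies in %, and AVG is the mean of the five accuracy columns.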

The perplexity is correct, but task accuracy drops sharply, from an average of 64.846 to 41.64.
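As a quick check on the table, the short sketch below recomputes AVG as the plain mean of the five accuracy columns and lists the per-task change. The numbers are copied from the table; nothing else is assumed.

```python
from statistics import mean

# Accuracy columns (in %) copied from the table above.
TASKS = ("PIQA", "ARC-e", "ARC-c", "HellaSwag", "WinoGrande")
FP16 = (78.07, 76.26, 43.52, 57.16, 69.22)
SAGE = (55.93, 28.37, 25.09, 49.24, 49.57)

# AVG is the unweighted mean of the five accuracies.
avg_fp16 = round(mean(FP16), 3)
avg_sage = round(mean(SAGE), 3)

# Per-task change in percentage points (negative means accuracy was lost).
deltas = {task: round(s - f, 2) for task, f, s in zip(TASKS, FP16, SAGE)}

if __name__ == "__main__":
    print(f"AVG: {avg_fp16} -> {avg_sage}")
    for task, d in deltas.items():
        print(f"{task:<10} {d:+.2f}")
```

Several of the SageAttention scores sit close to chance level for these multiple-choice tasks (PIQA and WinoGrande have two options, so roughly 50%; ARC questions mostly have four, so roughly 25%). That suggests the outputs in this run were broken rather than merely less precise.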

jason-huang03 commented 4 days ago

Hi, which GPU are you using?

shiyuetianqiang commented 4 days ago

> Hi, I set F.scaled_dot_product_attention = sageattn in modeling_llama.py and ran the inference code.

I did this too; however, I got a NaN perplexity. My GPU is an A100.

laomao0 commented 3 days ago

> Hi, which GPU are you using?

An A100.

jason-huang03 commented 2 days ago

Could one of you kindly provide your running script? The code has gone through numerous tests and should behave correctly.

laomao0 commented 1 day ago

> Could one of you kindly provide your running script? The code has gone through numerous tests and should behave correctly.

Maybe it's because my CUDA version is 11.8? I see that it requires CUDA 12.4.

jt-zhang commented 1 day ago

> Hi, I set F.scaled_dot_product_attention = sageattn in modeling_llama.py and ran the inference code.
>
> I did this too; however, I got a NaN perplexity. My GPU is an A100.

Hi, you can refer to the following code.


import torch
from types import MethodType
from typing import Optional, Tuple

from sageattention import sageattn
from transformers import AutoModelForCausalLM
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv


# Adapted from LlamaAttention.forward
def LlamaSageAttnForward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

    assert not output_attentions, "Output attentions not supported"
    assert attention_mask is None, "Attention mask not supported"
    assert self.num_key_value_groups == 1, "GQA will be supported in near future"
    bsz, q_len, _ = hidden_states.size()

    # Project to Q/K/V and reshape to (batch, heads, seq_len, head_dim).
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Apply rotary position embeddings to Q and K.
    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    if q_len == 1:
        # Decoding (a single query token): use PyTorch SDPA.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
    else:
        # Prefill: run SageAttention on (batch, heads, seq_len, head_dim) tensors,
        # i.e. tensor_layout="HND", with causal masking.
        attn_output = sageattn(
            query_states,
            key_states,
            value_states,
            tensor_layout="HND",
            is_causal=True,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value

# load llama model
model = AutoModelForCausalLM.from_pretrained(...)
# Replace every decoder layer's attention forward with the SageAttention version.
for layer in model.model.layers:
    layer.self_attn.forward = MethodType(LlamaSageAttnForward, layer.self_attn)

jt-zhang commented 1 day ago

Also, if your CUDA version is older than 12.4, you need to upgrade it.
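To check this requirement before installing, a small sketch can parse a version string such as the one reported by nvcc --version or torch.version.cuda and compare it with 12.4. The helper names are hypothetical, and the 12.4 threshold is taken from this thread rather than from the package metadata.

```python
from typing import Optional, Tuple

# Minimum CUDA version mentioned in this thread.
REQUIRED_CUDA: Tuple[int, int] = (12, 4)


def parse_cuda_version(version: Optional[str]) -> Optional[Tuple[int, int]]:
    """Turn strings like "11.8" or "12.4.1" into (major, minor); None if unknown."""
    if not version:
        return None  # e.g. torch.version.cuda is None on CPU-only builds
    parts = version.strip().split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    return major, minor


def cuda_meets_requirement(
    version: Optional[str], required: Tuple[int, int] = REQUIRED_CUDA
) -> bool:
    """True if the parsed (major, minor) is at least `required` (tuple comparison)."""
    parsed = parse_cuda_version(version)
    return parsed is not None and parsed >= required
```

torch.version.cuda is the CUDA version PyTorch was built against, while nvcc --version shows the local toolkit. If you build the extension from source, the toolkit version is likely the one that matters.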

shiyuetianqiang commented 1 day ago

> Hi, I set F.scaled_dot_product_attention = sageattn in modeling_llama.py and ran the inference code.
>
> I did this too; however, I got a NaN perplexity. My GPU is an A100.
>
> Hi, you can refer to the following code.

Thanks for your response, I found the reason. It seems that SageAttention does not work well with multiple GPUs. For example, when I use 8 GPUs, the attention outputs become NaN at the fifth layer (Llama-2-7B). In contrast, the results are normal when I use just one GPU.
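One way to test whether the NaNs line up with GPU boundaries: Llama-2-7B has 32 decoder layers, so if the 8-GPU run split them evenly (4 per GPU), layer index 4, the fifth layer, would be the first layer on the second GPU. The sketch below lists the layers at which the device changes. It assumes the model was loaded with a device_map through Transformers/Accelerate, so that model.hf_device_map exists and decoder layers are keyed model.layers.<i>. The helper names are hypothetical, and the even split is only an assumption, since device_map="auto" does not always balance layers evenly.

```python
import re
from typing import Dict, List, Tuple, Union

Device = Union[int, str]

# Decoder layers in a Llama device map are keyed "model.layers.<i>".
_LAYER_KEY = re.compile(r"^model\.layers\.(\d+)$")


def layer_devices(device_map: Dict[str, Device]) -> List[Tuple[int, Device]]:
    """Return (layer_index, device) pairs, sorted by layer index."""
    pairs = []
    for name, device in device_map.items():
        match = _LAYER_KEY.match(name)
        if match:
            pairs.append((int(match.group(1)), device))
    return sorted(pairs)


def device_boundaries(device_map: Dict[str, Device]) -> List[int]:
    """Indices of layers placed on a different device than the layer before them."""
    layers = layer_devices(device_map)
    return [
        idx
        for (_, prev_dev), (idx, dev) in zip(layers, layers[1:])
        if dev != prev_dev
    ]
```

For example, print(device_boundaries(model.hf_device_map)). If the first NaN layer coincides with the first boundary, that points (without proving it) toward a kernel running on a different device than its inputs; running the attention call inside with torch.cuda.device(query_states.device): is one thing worth trying until an upstream fix lands.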

jason-huang03 commented 22 hours ago

Yes, this was also observed in PR #50, and we will fix it in the next update.