laomao0 opened this issue 5 days ago
hi, which GPU do you use?
Hi, I set F.scaled_dot_product_attention = sageattn in modeling_llama.py, and run the inference code.
I did this too; however, I got a PPL of NaN. The GPU is an A100.
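For reference, the patch being described is the one-line plug-and-play substitution; a minimal sketch, assuming the sageattention package exports sageattn as a stand-in for F.scaled_dot_product_attention (the model id below is only an example):

# Minimal sketch of the patch described above. Assumes `sageattn` can stand in
# for F.scaled_dot_product_attention, as in the original report; the model id
# is only an example.
import torch
import torch.nn.functional as F
from sageattention import sageattn
from transformers import AutoModelForCausalLM

F.scaled_dot_product_attention = sageattn  # patch before running inference

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
).cuda()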
hi, which GPU do you use?
A100
Would any of you kindly provide a running script? The code has gone through numerous tests and should behave correctly.
Maybe it is because my CUDA version is 11.8? I see it requires version 12.4.
Hi, I set F.scaled_dot_product_attention = sageattn in modeling_llama.py, and run the inference code. I did this too; however, I got a PPL of NaN. The GPU is an A100.
Hi, you can refer to the following code.
# Imports to make the snippet self-contained; the helper names come from
# transformers' Llama implementation.
from types import MethodType
from typing import Optional, Tuple

import torch
from sageattention import sageattn
from transformers import AutoModelForCausalLM
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

# Adapted from LlamaAttention.forward
def LlamaSageAttnForward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    assert not output_attentions, "Output attentions not supported"
    assert attention_mask is None, "Attention mask not supported"
    assert self.num_key_value_groups == 1, "GQA will be supported in the near future"

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    if q_len == 1:
        # Decode path (single-token query): fall back to torch SDPA.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous
        # inputs with custom attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement
        # instead of an inline conditional assignment in SDPA to support both torch.compile's dynamic
        # shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
    else:
        # Prefill path: do attention with SageAttention here. The call below is an assumption:
        # sageattn taking (bsz, num_heads, seq_len, head_dim) tensors with is_causal=True,
        # mirroring its use as an SDPA replacement.
        attn_output = sageattn(query_states, key_states, value_states, is_causal=True)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value

# load the Llama model and patch every attention layer's forward
model = AutoModelForCausalLM.from_pretrained(...)
for layer in model.model.layers:
    layer.self_attn.forward = MethodType(LlamaSageAttnForward, layer.self_attn)
Also, if CUDA Version < 12.4, you need to update it.
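Regarding the 11.8 vs. 12.4 question above, a quick sketch for checking which CUDA toolkit your PyTorch build reports (the version requirement presumably refers to the toolkit, so this is the first thing to verify):

# Check the CUDA toolkit version the installed PyTorch build was compiled with,
# plus the detected GPU.
import torch

print(torch.version.cuda)                   # e.g. "11.8" or "12.4"
print(torch.cuda.get_device_name(0))        # e.g. "NVIDIA A100-SXM4-80GB"
print(torch.cuda.get_device_capability(0))  # (8, 0) for A100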
Thanks for your response. I found the reason: it seems that SageAttention does not work very well with multiple GPUs. For example, when I use 8 GPUs, the attention outputs become NaN at the fifth layer (Llama-2-7B). In contrast, the results are normal when I use just one GPU.
Yes, it is observed in PR #50 and we will fix this in the next update.
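As a hedged diagnostic sketch (not from the thread) for the multi-GPU case: forward hooks that report the first attention layer whose output contains NaN when the model is sharded with device_map="auto".

# Register forward hooks to find the first attention layer whose output
# contains NaN when the model is sharded across several GPUs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

def make_hook(idx):
    def hook(module, args, output):
        out = output[0] if isinstance(output, tuple) else output
        if torch.isnan(out).any():
            print(f"NaN in attention output of layer {idx} on {out.device}")
    return hook

for i, layer in enumerate(model.model.layers):
    layer.self_attn.register_forward_hook(make_hook(i))

inputs = tokenizer("Hello world", return_tensors="pt").to(model.device)
with torch.no_grad():
    model(**inputs)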
Hi, I set F.scaled_dot_product_attention = sageattn in modeling_llama.py and run the inference code. I see it runs sageattn_qk_int8_pv_fp16_cuda in sageattention/core.py. The results are:
Llama-2-7b                     wiki    c4      PIQA   ArcE   ArcC   Hella  WinoG  AVG
FP16                           5.12    6.63    78.07  76.26  43.52  57.16  69.22  64.846
sageattn_qk_int8_pv_fp16_cuda  5.1178  6.7139  55.93  28.37  25.09  49.24  49.57  41.64
The PPL is correct, but the accuracy drops a lot.
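For reproducing the accuracy columns, a hedged sketch using lm-evaluation-harness (v0.4.x) and its simple_evaluate API; the task names and batch size are assumptions, and the global patch mirrors the setup in the original report.

# Hedged sketch of how the accuracy columns could be reproduced with
# lm-evaluation-harness. Task names and batch size are assumptions.
import torch.nn.functional as F
from sageattention import sageattn
import lm_eval

F.scaled_dot_product_attention = sageattn  # apply the patch before the harness loads the model

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-2-7b-hf,dtype=float16",
    tasks=["piqa", "arc_easy", "arc_challenge", "hellaswag", "winogrande"],
    batch_size=8,
)
print(results["results"])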