Closed: BoxiangW closed this PR 3 weeks ago.
Hi @cyanguwa, could you help take a look at this PR?
/te-ci pytorch
Hi @timmoon10, I have signed off my commits. It seems the CI pipelines are waiting for approval right now.
Ah, actually I was incorrect. The softmax_scale we use here is the reciprocal of the softmax_scale used in Flash Attention, so norm_factor is a better name. Better yet, we would want to refactor our implementation so we have the same scaling-factor convention as Flash Attention and PyTorch.
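For reference, a minimal sketch of how the two conventions relate; the shapes and variable names below are purely illustrative and are not the actual TE internals:

import torch

# Flash Attention / PyTorch convention: multiply the scores by softmax_scale.
# Transformer Engine's convention: divide the scores by norm_factor.
q = torch.randn(2, 4, 8, 16)  # [batch, heads, seq, head_dim]
k = torch.randn(2, 4, 8, 16)

norm_factor = q.shape[-1] ** 0.5    # divisor, e.g. sqrt(head_dim)
softmax_scale = 1.0 / norm_factor   # multiplier, its reciprocal

scores_div = torch.matmul(q, k.transpose(-2, -1)) / norm_factor
scores_mul = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
assert torch.allclose(scores_div, scores_mul)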
Yes, I found this part under FlashAttention, where softmax_scale is 1.0/norm_factor, so I am going to change the kwarg name back to norm_factor:
with self.attention_dropout_ctx():
    output = attn_forward_func_with_cp(
        self.training, query_layer, key_layer, value_layer,
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
        self.attention_dropout if self.training else 0.0,
        cp_group, cp_global_ranks, cp_stream,
        softmax_scale=1.0/self.norm_factor,
        qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,
        attn_mask_type=attn_mask_type,
        deterministic=self.deterministic
    )
/te-ci pytorch
/te-ci pytorch
Seems like there's a failing test, but when I look at the details, it says everything passed.
Description
Add an arg option norm_factor to DotProductAttention for MLA support.
Fixes # (issue)
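A rough usage sketch of what this would enable. This is an assumption-laden illustration: it assumes the new keyword lands as norm_factor, that the existing num_attention_heads/kv_channels constructor arguments are unchanged, and the head dims are just illustrative MLA-style values.

import math
import transformer_engine.pytorch as te

# Hypothetical MLA-style setup: the query/key head dim (nope + rope parts)
# differs from the value head dim, so a scaling factor derived only from
# kv_channels would not match the intended 1/sqrt(qk_head_dim).
qk_head_dim = 192   # illustrative: 128 (nope) + 64 (rope)
v_head_dim = 128    # illustrative value head dim

attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=v_head_dim,
    norm_factor=math.sqrt(qk_head_dim),  # assumed new kwarg from this PR
)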
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: