NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance and lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Change `norm_factor` into `softmax_scale` and add kwarg to `DotProductAttention` #897

Closed: BoxiangW closed this pull request 3 weeks ago

BoxiangW commented 3 weeks ago

Description

Add an optional `norm_factor` argument to `DotProductAttention` for MLA support.
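
For context, a hypothetical usage sketch (the kwarg name and the MLA-style numbers below are illustrative assumptions, not taken from this PR): in MLA the effective query/key head dimension can differ from `kv_channels`, so the default `1/sqrt(kv_channels)` softmax scaling is not always the intended one, and an explicit normalization factor lets the caller override it.

```python
# Hypothetical sketch of the proposed argument; the kwarg name (norm_factor)
# and the MLA-style dimensions are assumptions for illustration only.
import math
from transformer_engine.pytorch import DotProductAttention

num_heads = 16
kv_channels = 128   # value head dimension
qk_head_dim = 192   # MLA-style query/key head dimension, differs from kv_channels

attn = DotProductAttention(
    num_attention_heads=num_heads,
    kv_channels=kv_channels,
    attention_dropout=0.0,
    # Hypothetical new kwarg: attention scores would be divided by this factor
    # instead of the default sqrt(kv_channels).
    norm_factor=math.sqrt(qk_head_dim),
)
```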

Fixes # (issue)

Type of change

Changes

Please list the changes introduced in this PR:

Checklist:

BoxiangW commented 3 weeks ago

Hi @cyanguwa, could you help take a look at this PR?

ksivaman commented 3 weeks ago

/te-ci pytorch

BoxiangW commented 3 weeks ago

Hi @timmoon10, I have signed off my commits. It seems like the CI runs are waiting for approval right now.

timmoon10 commented 3 weeks ago

Ah, actually I was incorrect. The `softmax_scale` we use here is the reciprocal of the `softmax_scale` used in Flash Attention, so `norm_factor` is a better name. Better yet, we would want to refactor our implementation to use the same scaling-factor convention as Flash Attention and PyTorch.
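
To make the convention difference concrete, here is a standalone illustration (not code from this PR or from either library): Transformer Engine divides the attention scores by `norm_factor`, while Flash Attention and PyTorch multiply them by `softmax_scale`, so the two conventions agree exactly when `softmax_scale = 1.0 / norm_factor`.

```python
# Standalone illustration of the two scaling conventions.
import math
import torch

q = torch.randn(2, 4, 8, 64)   # (batch, heads, seq, head_dim)
k = torch.randn(2, 4, 8, 64)

norm_factor = math.sqrt(q.size(-1))   # TE-style: divide scores by this
softmax_scale = 1.0 / norm_factor     # flash-attn / PyTorch style: multiply by this

scores_te = torch.matmul(q, k.transpose(-2, -1)) / norm_factor
scores_fa = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale

# Identical up to floating-point rounding.
assert torch.allclose(scores_te, scores_fa)
```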

BoxiangW commented 3 weeks ago

> Ah, actually I was incorrect. The `softmax_scale` we use here is the reciprocal of the `softmax_scale` used in Flash Attention, so `norm_factor` is a better name. Better yet, we would want to refactor our implementation to use the same scaling-factor convention as Flash Attention and PyTorch.

Yes, I found this part under the FlashAttention path, where `softmax_scale` is `1.0 / norm_factor`, so I am going to change the kwarg name back to `norm_factor`:

            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training, query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic
                )
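
For reference on the convention mentioned above (again just an illustration, assuming PyTorch >= 2.1 for the `scale` argument): PyTorch's `scaled_dot_product_attention` takes the multiplicative scale, so the value handed to it or to flash-attn corresponds to `1.0 / norm_factor`, exactly as in the call above.

```python
# Illustration: PyTorch's SDPA uses the multiplicative convention.
import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 64)
k = torch.randn(2, 4, 8, 64)
v = torch.randn(2, 4, 8, 64)

norm_factor = math.sqrt(q.size(-1))

# Passing scale=1.0/norm_factor reproduces the default 1/sqrt(head_dim) scaling.
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=1.0 / norm_factor)
out_default = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out_explicit, out_default)
```
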
timmoon10 commented 3 weeks ago

/te-ci pytorch

cyanguwa commented 3 weeks ago

/te-ci pytorch

BoxiangW commented 3 weeks ago

It seems like there's a failing test, but when I look at the details, everything reports successful.