Closed: turboderp closed this 6 months ago
`None` in SDPA or Flash Attention is the same as 1 / sqrt(d). `scale_attn_weights` is the parameter used to decide whether to use 1 / sqrt(d) or 1.
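As a concrete check, a minimal sketch (assuming PyTorch >= 2.1, which exposes the `scale` argument of `scaled_dot_product_attention`; the `resolve_scale` helper is purely illustrative, not from the HF source):

```python
import math
import torch
import torch.nn.functional as F

def resolve_scale(head_dim: int, scale_attn_weights: bool) -> float:
    # scale_attn_weights=True -> the usual 1/sqrt(d); otherwise leave the logits unscaled.
    return 1.0 / math.sqrt(head_dim) if scale_attn_weights else 1.0

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# scale=None makes SDPA fall back to 1/sqrt(head_dim), i.e. the same scaling:
out_default = F.scaled_dot_product_attention(q, k, v)                                  # scale=None
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=resolve_scale(64, True))  # 1/sqrt(64)
assert torch.allclose(out_default, out_explicit)
```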
The fp32 arguments are just for numerical stability during training and honestly shouldn't be needed at inference.
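A sketch of what the fp32 option amounts to (an illustrative paraphrase, not the exact HF code): upcast the scores to float32 for the softmax, then cast back, which changes numerics but not the scaling itself.

```python
import torch

def softmax_maybe_fp32(attn_scores: torch.Tensor, upcast: bool = True) -> torch.Tensor:
    # Upcasting to fp32 for the softmax mirrors the attention_softmax_in_fp32 idea;
    # the 1/sqrt(d) scaling is applied to the scores before this step.
    if upcast:
        return torch.softmax(attn_scores.float(), dim=-1).to(attn_scores.dtype)
    return torch.softmax(attn_scores, dim=-1)
```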
Thank you. :+1:
Can you elaborate on the significance of the softmax scaling? I can't find it referenced in the paper, and it seems to be applied differently for each of the three attention methods in the HF implementation:

- In the eager path, the softmax scaling is applied when `scale_attention_softmax_in_fp32`, `attention_softmax_in_fp32` and `scale_attn_weights` are all set.
- In the SDPA path, the scale is passed as `None`, though the code seems prepared to change it to 1 if `scale_attn_weights` were unset. (?)
- In the flash-attn path, `_flash_attention_forward` takes a scale argument, but that argument isn't passed, so it defaults to `None`. Presumably the models are trained with flash-attn, so is this just not actually relevant?
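On the flash path specifically, a hedged sketch of why the unset argument is harmless (assuming the `flash_attn` package's `flash_attn_func`, whose `softmax_scale` also falls back to 1/sqrt(head_dim) when left as `None`; flash-attn requires a CUDA device and fp16/bf16 tensors):

```python
import math
import torch
from flash_attn import flash_attn_func  # requires the flash-attn package and a CUDA GPU

# flash-attn expects (batch, seq_len, num_heads, head_dim) tensors in fp16/bf16.
q = torch.randn(1, 16, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 16, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 16, 8, 64, device="cuda", dtype=torch.float16)

# Leaving softmax_scale unset (None) is not "no scaling": it defaults to 1/sqrt(head_dim).
out_default = flash_attn_func(q, k, v)
out_explicit = flash_attn_func(q, k, v, softmax_scale=1.0 / math.sqrt(64))
assert torch.allclose(out_default, out_explicit, atol=1e-3)
```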