lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Pass custom scale to flash attention #244

Closed · Subuday closed this 5 months ago

Subuday commented 5 months ago

In the current implementation, the custom scale is not passed to flash attention.
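
For reference, recent PyTorch (2.1 and later) exposes a `scale` keyword on `F.scaled_dot_product_attention`, so forwarding the custom scale could look roughly like this minimal sketch (toy shapes, not the actual x-transformers code):

```python
import torch
import torch.nn.functional as F

# toy tensors: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

custom_scale = 0.1  # e.g. a fixed scale instead of the default 1 / sqrt(64)

# PyTorch >= 2.1: pass the custom scale straight through to SDPA / flash attention
out = F.scaled_dot_product_attention(q, k, v, scale = custom_scale)
```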

lucidrains commented 5 months ago

@Subuday oh shoot, yes i think you are right, thanks!

lucidrains commented 5 months ago

wait, did they introduce the scale kwarg recently? i must have done it this way because they didn't have it previously. we'll need to enforce a certain pytorch version if so
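
(the `scale` kwarg landed in pytorch 2.1, if i recall. if enforcing that version is undesirable, one possible workaround, sketched here as an assumption rather than the library's actual fix, is to fold the custom scale into q, since SDPA always divides by sqrt(head_dim):)

```python
import math
import torch
import torch.nn.functional as F

def sdpa_with_custom_scale(q, k, v, custom_scale):
    # version-agnostic workaround sketch: SDPA applies 1 / sqrt(head_dim) by default,
    # so pre-multiplying q by custom_scale * sqrt(head_dim) reproduces
    # softmax(q @ k^T * custom_scale) without needing the `scale` kwarg
    head_dim = q.shape[-1]
    q = q * (custom_scale * math.sqrt(head_dim))
    return F.scaled_dot_product_attention(q, k, v)
```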

lucidrains commented 5 months ago

@Subuday there is a bug, however, where custom scales are not applied in the absence of qk norm. let me quickly fix that
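
(roughly, the failure mode looks like this hypothetical sketch; the function and argument names are made up for illustration and are not the actual x-transformers code:)

```python
import torch
import torch.nn.functional as F

def attend(q, k, v, custom_scale, qk_norm = False):
    # hypothetical sketch of the bug pattern: the custom scale is only forwarded
    # on the qk-norm branch, so without qk norm it silently falls back to the
    # default 1 / sqrt(head_dim)
    if qk_norm:
        q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)
        return F.scaled_dot_product_attention(q, k, v, scale = custom_scale)

    # bug: custom_scale is dropped on this path
    return F.scaled_dot_product_attention(q, k, v)
```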

lucidrains commented 5 months ago

@Subuday let's go with this for now. they didn't have this scale in previous versions

Subuday commented 5 months ago

Thanks!