lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

[Question] very small attention scores #245

Closed · pfeatherstone closed this 2 months ago

pfeatherstone commented 5 months ago

This is a general question about training transformers, but it relates to a specific feature of this repo. In my tests I've curiously found that the attention layer's dot-product matrix, sometimes called "scores", sometimes "dots", has very small values. In my case, all the non-masked elements lie in the range [10^-9, 10^-8], while the masked elements have the value -3.4028*10^+38 as expected. The attention map, which is just softmax(dots), yields 0 for the masked elements and then exactly the same value for every non-masked element, presumably due to numerical imprecision. Has anybody ever seen this? If so, what are potential workarounds? This repo offers the setting attn_qk_norm, which just L2-normalizes queries and keys. Does this solve the problem? Many thanks
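
A minimal numerical sketch of what I mean, in plain PyTorch (not this repo's internals): when the dot products are this small, the row-wise softmax is essentially uniform, whereas l2-normalizing q and k and multiplying by a temperature (the scale of 10 here is made up) keeps the scores at a sensible magnitude.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# pathologically small queries / keys, as in the situation described above
q = torch.randn(1, 8, 64) * 1e-4
k = torch.randn(1, 8, 64) * 1e-4

# raw scaled dot products end up around 1e-8, so the softmax is ~uniform
dots = torch.einsum('b i d, b j d -> b i j', q, k) / 64 ** 0.5
print(dots.abs().max())                   # tiny, roughly 1e-8
print(F.softmax(dots, dim=-1)[0, 0])      # ~0.125 everywhere (1/8)

# cosine-sim style attention: l2-normalize q and k, then apply a temperature
scale = 10.0                              # fixed here; could also be learned
qn, kn = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
dots_norm = torch.einsum('b i d, b j d -> b i j', qn, kn) * scale
print(F.softmax(dots_norm, dim=-1)[0, 0]) # scores are O(1), softmax no longer degenerate
```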

pfeatherstone commented 5 months ago

I'll give it a go, thanks

pfeatherstone commented 5 months ago

Funny how some publications can just be a case of: add a conv there and see what happens

pfeatherstone commented 5 months ago

@lucidrains Good news, using attn_qk_norm seems to have solved my problem. Now all my attention "scores"/"dots" are O(1), except for the masked elements, which are -3.4028*10^+38 as expected. So the softmaxed attention map is now looking much more sensible.

It might be worth mentioning in the README that attn_qk_norm has this nice property. You already mention that it can help with overflow, but it seems it can also help with underflow, or whatever this is.
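
For reference, this is roughly how I'm enabling it (kwargs as I understand them from the README; double-check against the current version):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_qk_norm = True,  # l2-normalize queries and keys before the dot product
        attn_flash = True     # assumed kwarg for flash attention; check the README
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)  # (1, 1024, 256)
```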

pfeatherstone commented 5 months ago

Unfortunately, talking_heads isn't compatible with flash attention, and I can't afford not to use flash attention. I also had a look at sparse_topk, thinking that would help too, but again, it's not compatible with flash attention. Makes sense; see the sketch below.
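
For what it's worth, my rough understanding of why it makes sense: talking heads mixes the scores across the head dimension with a learned h x h projection before (and after) the softmax, so the full scores matrix has to be materialized, which is exactly what flash attention avoids. A shape-only sketch of the pre-softmax step (not this repo's implementation):

```python
import torch

b, h, n = 2, 8, 128
dots = torch.randn(b, h, n, n)   # pre-softmax attention scores, materialized in full

# talking heads: learned h x h mixing across the head dimension
pre_softmax_proj = torch.randn(h, h) / h ** 0.5
dots = torch.einsum('b h i j, h g -> b g i j', dots, pre_softmax_proj)

attn = dots.softmax(dim=-1)
# a second h x h projection would normally be applied to attn after the softmax
```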

lucidrains commented 5 months ago

@pfeatherstone nice! yea i'm bullish on cosine sim attention. Tero Karras recently used it in his new u-net with great results

pfeatherstone commented 5 months ago

Makes you wonder, what percentage of a model is just some kind of normalization. Probably quite high. That seems like a flaw. Someone needs to invent a new neural network architecture where normalization is like < 1% of your layers.

pfeatherstone commented 5 months ago

> @pfeatherstone nice! yea i'm bullish on cosine sim attention. Tero Karras recently used it in his new u-net with great results

What's the state of https://github.com/lucidrains/flash-cosine-sim-attention ? I like the idea of fusing flash attention with l2-normalized q and k.

Also, did you consider using https://github.com/NVIDIA/cutlass for the CUDA backend? I think Tri Dao used that library for Flash Attention 2, and it allowed him to write much more concise and ultimately better code (according to a podcast interview).