lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License
4.42k stars, 377 forks

Support for NormSoftmax #215

Closed: catid closed this issue 7 months ago

catid commented 7 months ago

Based on this paper: https://openreview.net/pdf?id=4g7nCbpjNwd

Would require editing this line:

https://github.com/lucidrains/x-transformers/blob/aabee05d6bca6d74646156009159c55f8d27d884/x_transformers/attend.py#L278C70-L278C75

And replacing the * scale with:

    if self.norm_softmax:
        dots = dots / torch.clamp(dots.std(dim=-1, keepdim=True), min=1e-6)
    else:
        dots *= scale

And then something similar in the other flash attention path
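For readers who want to try the idea outside the library, here is a minimal standalone sketch of the proposed change. It is not the x-transformers implementation: the function name `attention_logits` and the `temperature` argument are hypothetical, the tensor layout `(batch, heads, seq, dim_head)` is assumed, and masking is ignored.

```python
import torch

def attention_logits(q, k, norm_softmax=False, temperature=1.0, eps=1e-6):
    """Pre-softmax attention logits (hypothetical helper, not library code).

    q, k: tensors of shape (batch, heads, seq, dim_head).
    """
    dots = q @ k.transpose(-1, -2)  # (batch, heads, seq_q, seq_k)

    if norm_softmax:
        # Divide each row of logits by its own standard deviation
        # instead of multiplying by the fixed 1/sqrt(dim_head) scale.
        std = dots.std(dim=-1, keepdim=True)
        return dots / (temperature * torch.clamp(std, min=eps))

    # Standard scaled dot-product attention.
    return dots * q.shape[-1] ** -0.5

torch.manual_seed(0)
q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)

logits = attention_logits(q, k, norm_softmax=True)
attn = logits.softmax(dim=-1)
```

With `norm_softmax=True` and `temperature=1.0`, every row of `logits` has unit standard deviation regardless of `dim_head`, which is the property the fixed scale only approximates.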

lucidrains commented 7 months ago

@catid oh interesting, reminds me a bit of https://arxiv.org/abs/2005.09561

there will also be a temperature involved

have you tried this? maybe i can run a quick experiment tonight

lucidrains commented 7 months ago

it won't be compatible with flash attention

catid commented 7 months ago

NormSoftmax CIFAR-10 benchmark results at epoch=60 using ViT-tiny:
baseline: 77.69%, sqrtd: 76.39%, inf: 77.53%

NormSoftmax CIFAR-10 benchmark results at epoch=300 using ViT-tiny:
baseline: 85.19%, inf: 85.07%

Manages to get about the same result without the extra parameters

lucidrains commented 7 months ago

@catid well yea, so they claim. cifar-10 is a tiny benchmark too

lucidrains commented 7 months ago

another engineering obstacle would be handling a masked standard dev
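One way the masked standard deviation could be handled is sketched below, assuming a boolean mask that is `True` where attention is allowed and a causal (lower-triangular) mask for the example. The helper name `masked_std` is hypothetical, and this is only an illustration of the obstacle, not code from x-transformers or from the paper.

```python
import torch

def masked_std(dots, mask, eps=1e-6):
    """Per-row std over allowed positions only (hypothetical helper).

    dots: (batch, heads, seq_q, seq_k) attention logits.
    mask: bool tensor broadcastable to dots, True = position is attended to.
    """
    mask = mask.expand_as(dots)
    count = mask.sum(dim=-1, keepdim=True).clamp(min=1)

    # Mean over allowed positions only.
    mean = dots.masked_fill(~mask, 0.0).sum(dim=-1, keepdim=True) / count

    # Squared deviations, with masked positions zeroed out.
    sq_dev = ((dots - mean) ** 2).masked_fill(~mask, 0.0)

    # Bessel's correction, to match the default behaviour of torch.std.
    var = sq_dev.sum(dim=-1, keepdim=True) / (count - 1).clamp(min=1)

    # Rows with a single allowed position have zero variance, so clamp.
    return var.sqrt().clamp(min=eps)

torch.manual_seed(0)
dots = torch.randn(1, 2, 5, 5)
causal = torch.ones(5, 5, dtype=torch.bool).tril()

std = masked_std(dots, causal)
normed = dots / std

# Apply the mask before the softmax, as in ordinary causal attention.
normed = normed.masked_fill(~causal, -torch.finfo(normed.dtype).max)
attn = normed.softmax(dim=-1)
```

The first row of a causal mask attends to a single position, so its standard deviation is zero and the clamp decides the scale. The softmax over that row is still 1 at the only allowed position.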

lucidrains commented 7 months ago

yea, let me run it tonight on enwik8, but if i don't see anything notable on the first or second try, probably will just drop this

catid commented 7 months ago

@lucidrains The masked stddev is like this right? https://github.com/catid/cifar10deepspeed/blob/fe5b399c5ab5f3ed11235d3dbe72952ce7c2be46/models/vit_small.py#L75

I think that's what I'm testing

lucidrains commented 7 months ago

@catid i'm thinking for autoregressive text generation (gpt), the triangular causal mask. you are masking out the diagonal?

catid commented 7 months ago

Yeah I'm just copying your vit_for_small_dataset.py

lucidrains commented 7 months ago

@catid ohh ok, do you see anything? have you run the experiments yourself? never trust anything a paper says unless you see the curves in front of you 😆

catid commented 7 months ago

The results I shared above are from my setup

lucidrains commented 7 months ago

@catid wow! ok, i actually put a lot of weight on results from internet randos

ok, let me try it tonight!

lucidrains commented 7 months ago

@catid wait, your results show norm softmax to be worse than baseline? is that accuracy?

lucidrains commented 7 months ago

@catid can you share a wandb report with training curves?

catid commented 7 months ago

I dunno, I mean the numbers are pretty close and I only ran N=1 trial, so I'm not sure if one method produces better accuracy than the other. Also, I don't have wandb integrated into my scripts yet (haven't learned how to use it yet).

lucidrains commented 7 months ago

ah, looks to be a negative result.