lucidrains / rela-transformer

Implementation of a Transformer using ReLA (Rectified Linear Attention) from https://arxiv.org/abs/2104.07012
MIT License

LayerNorm/GatedRMS inconsistency #1

Open inspirit opened 2 years ago

inspirit commented 2 years ago

Hi! Looking through the pipeline, it seems there are some inconsistencies with normalisation:

# ReLA
input to GRMSNorm
# att code
output: Linear(inner_dim, dim) + GRMSNorm
# next in FF module 
input to LayerNorm

Here we have a double-norm problem: the last layer in att is GRMSNorm and the first layer in FF is LayerNorm, so the same activations get normalised twice in a row.

Looking at the paper, it seems that in ReLA GRMSNorm is applied to the result of mult(attn, v) before the output projection, not after the projection as in this code. I'm also confused about the use of LayerNorm in FF: should it be GRMSNorm instead? That's not clear from the paper either.
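To make the ordering question concrete, here is a minimal, framework-free sketch of the two variants for a single token. Everything in it is an assumption for illustration: `rms_norm` is a plain RMS normalisation standing in for GRMSNorm (the gate is left out), and `av` and `w_out` are made-up values, not anything taken from the repo.

```python
import math

def rms_norm(x, eps=1e-8):
    """Plain RMS normalisation of one vector (stand-in for GRMSNorm, gate omitted)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / max(rms, eps) for v in x]

def linear(x, weight):
    """Apply a weight matrix given as a list of rows: out[i] = sum_j weight[i][j] * x[j]."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def paper_order(av, w_out):
    # Ordering attributed to the paper above: normalise attn @ v, then project.
    return linear(rms_norm(av), w_out)

def repo_order(av, w_out):
    # Ordering attributed to the repo code above: project, then normalise.
    return rms_norm(linear(av, w_out))

av = [1.0, -2.0, 3.0, 0.5]       # hypothetical aggregated values (inner_dim = 4)
w_out = [[0.5, 0.0, 0.0, 0.0],   # hypothetical output projection (4 -> 2)
         [0.0, 0.0, 2.0, 0.0]]

print(paper_order(av, w_out))    # output scale is set by the projection weights
print(repo_order(av, w_out))     # output always has unit RMS
```

With the repo ordering the block output is already normalised, which is why a following LayerNorm in FF looks like a second norm on the same activations.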

lucidrains commented 2 years ago

@inspirit hello there! yea, i kind of did some improvisation there

i'm using the sandwich normalization formulation from another paper https://arxiv.org/abs/2105.13290 rather than just normalizing the aggregated values directly

for the feedforward, i'm not entirely sure, probably wouldn't make that huge of a difference
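As a rough illustration of the sandwich formulation mentioned above (normalise on the way into a sublayer and again on the way out, before the residual add), here is a small plain-Python sketch. The names `sandwich_block`, `layer_norm`, and the toy `double` sublayer are hypothetical, and the norms carry no learned parameters; it reflects my reading of the idea rather than this repo's code.

```python
import math

def layer_norm(x, eps=1e-5):
    """Parameter-free layer norm over one vector: zero mean, (near) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sandwich_block(x, sublayer, pre_norm, post_norm):
    """Residual block with a norm on both sides of the sublayer."""
    branch = post_norm(sublayer(pre_norm(x)))   # norm -> sublayer -> norm
    return [a + b for a, b in zip(x, branch)]   # residual add

def double(h):
    # Toy stand-in for attention or feedforward.
    return [2.0 * v for v in h]

x = [1.0, 2.0, 3.0, 4.0]
out = sandwich_block(x, double, layer_norm, layer_norm)
print(out)
```

In this picture the norm after the attention output projection is the "post" half of the sandwich, and the norm at the start of FF is the "pre" half of the next block.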

inspirit commented 2 years ago

Aha, I see. Yup, I remember the sandwich norm paper :) Another difference I noticed: you use projection-based gating (with a Linear layer) in GRMSNorm, while the original paper uses simple per-element multiplication here: return normed_x * (x * gate).sigmoid(), where gate is a learned vector of size dim, something like nn.Parameter(torch.ones(dim)).
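The element-wise gating described above can be sketched without any framework. This is only an illustration for a single vector: `gated_rms_norm` is a hypothetical helper, the learned gain of RMSNorm is left out, and `gate` is passed in as a plain list instead of a trained parameter.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gated_rms_norm(x, gate, eps=1e-8):
    """RMS-normalise x, then scale each element by sigmoid(x_i * gate_i)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    normed = [v / max(rms, eps) for v in x]
    # Per-element gate: no Linear projection, just one weight per channel.
    return [n * sigmoid(v * w) for n, v, w in zip(normed, x, gate)]

x = [3.0, 4.0]
print(gated_rms_norm(x, gate=[0.0, 0.0]))  # gate of zeros halves the normed values
print(gated_rms_norm(x, gate=[1.0, 1.0]))  # positive inputs pass almost unchanged
```

Compared with a Linear-based gate, this version adds only `dim` parameters and each channel is gated by its own value alone.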

lucidrains commented 2 years ago

@inspirit ohh apologies, yea, i didn't build that correctly https://github.com/lucidrains/rela-transformer/commit/b58b121e59463b5c73fa5dea9297c841b1a3a362

let me know if that works! i've seen relu-based attention in another recent paper (https://github.com/lucidrains/FLASH-pytorch), so maybe there's something to it!

lucidrains commented 2 years ago

@inspirit how did it go? :) any interesting experimental results?

inspirit commented 2 years ago

It seems to be less stable compared to normal softmax attention. I fused it with Perceiver for my experiments; sometimes it gives slightly better results, sometimes not :) The reason might be the small model inner dimension (128) and the sparser attention that comes from using ReLU.
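The sparsity point can be seen on a single row of attention scores. The sketch below uses made-up scores and two hypothetical helpers; it only contrasts the two activations and is not the repo's attention code.

```python
import math

def softmax_weights(scores):
    """Standard softmax: every key gets a strictly positive weight."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def relu_weights(scores):
    """Rectified linear attention: non-positive scores are zeroed out."""
    return [max(s, 0.0) for s in scores]

scores = [1.2, -0.7, 0.0, -2.5, 0.3]   # hypothetical q . k scores for one query
soft = softmax_weights(scores)
rela = relu_weights(scores)

print(sum(w == 0.0 for w in soft))  # 0
print(sum(w == 0.0 for w in rela))  # 3
print(sum(rela))                    # 1.5, the weights are not normalised
```

The ReLU weights drop three of the five keys entirely and do not sum to one, so the scale of the aggregated values varies from query to query. That is the part the norm after aggregation has to absorb, and it may be one source of the instability.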

lucidrains commented 2 years ago

@inspirit yea, i thought it would be too good to be true if relu attention worked 😞 it must have worked for FLASH because they confine their quadratic attention to local windows