lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License

Output inconsistent for autoregressive performer #82

Open · GanjinZero opened this issue 2 years ago

GanjinZero commented 2 years ago

I want to use the autoregressive Performer for decoding.

import torch
from performer_pytorch import Performer
from torch import nn

attn = Performer(
    dim = 128,
    heads = 4,
    depth = 2,
    local_attn_heads = 2,
    dim_head = 32,
    causal = True,
    use_scalenorm = True
).cuda()
attn.fix_projection_matrices_()  # freeze the random feature projections so the two runs are comparable
attn.eval()

# 900-token input built from four constant segments
x = torch.cat([torch.ones((1,50,128)), torch.ones((1,50,128)) * 0.5, torch.ones((1,400,128)) * (-0.1), torch.ones((1,400,128)) * 0.1], dim=1).cuda()
y = attn(x)

# first 100 tokens only; with causal attention the outputs at these
# positions should not depend on anything after position 100
x0 = x[0,0:100].unsqueeze(0)
y0 = attn(x0)

print((y[0][0:100]-y0[0][0:100]).norm())

The output is tensor(0.0003, device='cuda:0', grad_fn=<...>).

If I turn off use_scalenorm, the output is tensor(0.0085, device='cuda:0', grad_fn=<...>). This shows the output of the autoregressive Performer is inconsistent: the outputs for the first 100 positions change when later tokens are appended.
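For comparison (not part of the original report, just a sanity check with standard PyTorch attention): with exact softmax attention under a causal mask, the outputs at positions 0..99 are a function of the first 100 tokens only, so the same prefix comparison comes out at essentially zero.

import torch
from torch import nn

torch.manual_seed(0)
# plain causal multi-head attention as a reference point
mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True).eval()

def causal_out(x):
    n = x.shape[1]
    # boolean mask: True marks positions that may NOT be attended to (the future)
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    out, _ = mha(x, x, x, attn_mask=mask, need_weights=False)
    return out

x = torch.randn(1, 900, 128)
with torch.no_grad():
    y_full = causal_out(x)             # full 900-token sequence
    y_prefix = causal_out(x[:, :100])  # first 100 tokens only

print((y_full[0, :100] - y_prefix[0, :100]).norm())  # ~0, up to floating point error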

GanjinZero commented 2 years ago

I guess the reason is k = create_kernel(k, is_query = False) in FastAttention.forward: inside the softmax_kernel operation there is the line

data_dash = ratio * (torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True)) + eps)

The term torch.amax(data_dash, dim=(-1, -2), keepdim=True) takes the max over the sequence dimension as well, so it contains information from later time steps, and that information is passed into the hidden states of earlier time steps.
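To make the leak concrete, here is a simplified sketch (not the library's actual softmax_kernel: the random projection, ratio scaling, and exact diag_data computation are elided, and the key_features_* names are made up): a stabilizer taken as a max over both the feature and sequence dimensions makes every position's kernel features depend on the whole sequence, while a per-position max (over the feature dimension only) does not.

import torch

def key_features_global_max(k, eps=1e-4):
    # simplified stand-in for the non-query branch: the stabilizer is the max
    # over BOTH the feature and sequence dims, so features at position t
    # depend on positions > t as well
    diag = (k ** 2).sum(dim=-1, keepdim=True) / 2
    stabilizer = torch.amax(k, dim=(-1, -2), keepdim=True)
    return torch.exp(k - diag - stabilizer) + eps

def key_features_per_position_max(k, eps=1e-4):
    # per-position stabilizer: max over the feature dim only, so position t
    # never sees information from later positions
    diag = (k ** 2).sum(dim=-1, keepdim=True) / 2
    stabilizer = torch.amax(k, dim=-1, keepdim=True)
    return torch.exp(k - diag - stabilizer) + eps

k = torch.randn(1, 200, 64)

full_g = key_features_global_max(k)[:, :100]
pref_g = key_features_global_max(k[:, :100])
print((full_g - pref_g).norm())   # nonzero: truncating the sequence changes the stabilizer

full_p = key_features_per_position_max(k)[:, :100]
pref_p = key_features_per_position_max(k[:, :100])
print((full_p - pref_p).norm())   # zero: per-position features are unaffected by later tokens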

lucidrains commented 2 years ago

@GanjinZero oh shoot, yea, those maxes are for numerical stability, but i think they should be detached: https://github.com/lucidrains/performer-pytorch/commit/fc8b78441b1e27eb5d9b01fc738a8772cee07127. can you let me know if this resolves the issue on your end?
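For reference, a minimal sketch of what detaching the stabilizer means (an illustration of the comment above, not the code from that commit): .detach() removes the max from the autograd graph so gradients no longer flow through it, while the forward values are unchanged.

import torch

def stabilized_exp(data_dash, diag_data, ratio, eps=1e-4):
    # hypothetical sketch: treat the stabilizing max as a constant for autograd
    # by detaching it; forward outputs are identical to the undetached version
    stabilizer = torch.amax(data_dash, dim=(-1, -2), keepdim=True).detach()
    return ratio * (torch.exp(data_dash - diag_data - stabilizer) + eps)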