lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in PyTorch

Cross-attention with arbitrary causal mask #94

Open BarKetPlace opened 1 year ago

BarKetPlace commented 1 year ago

Hi, I am looking for a way to use the code in a causal cross-attention setting, with an arbitrary causal mask that I could specify.

For example:

import torch
from performer_pytorch import CausalCrossAttention  # does not exist yet - this is the interface I have in mind

attn = CausalCrossAttention(
    dim = 512,
    heads = 8
).cuda()

x = torch.randn(1, 1024, 512).cuda()
context = torch.randn(1, 512, 512).cuda()

# boolean mask of shape (1, 1024, 512): query i may attend to context position j iff i > j
mask = (torch.arange(1024).reshape(-1, 1) > torch.arange(512).reshape(1, -1)).unsqueeze(0)

attn(x, context = context, mask = mask) # (1, 1024, 512)
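
For reference, the closest thing in the current library seems to be the non-causal CrossAttention shown in the README (below). As far as I can tell, its mask / context_mask arguments are per-token padding masks of shape (batch, seq_len), not full (queries x keys) attention masks, so they cannot express the mask above:

import torch
from performer_pytorch import CrossAttention

attn = CrossAttention(
    dim = 512,
    heads = 8
).cuda()

x = torch.randn(1, 1024, 512).cuda()
context = torch.randn(1, 512, 512).cuda()

# context_mask flags valid context tokens per position, shape (1, 512)
context_mask = torch.ones(1, 512).bool().cuda()

attn(x, context = context, context_mask = context_mask) # (1, 1024, 512)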

Is anybody working on this? I would be happy to implement it myself, but would appreciate some pointers/ideas on how to approach it with the library. One direction I have been considering is sketched below.
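
A minimal sketch, under clearly labeled assumptions: FAVOR+ never materializes the (queries x keys) attention matrix, so a fully arbitrary mask seems out of reach, but a monotone "staircase" mask like the one above (query i attends to context positions j < boundary[i], with boundary nondecreasing) should be able to reuse the same prefix-sum trick as causal linear attention. The sketch uses the elu(x) + 1 feature map from Katharopoulos et al. as a stand-in for Performer's random features; staircase_linear_cross_attention and boundary are names I made up for illustration.

import torch
import torch.nn.functional as F

def staircase_linear_cross_attention(q, k, v, boundary):
    # q: (b, h, n, d); k, v: (b, h, m, d)
    # boundary: (n,) long tensor; query i may attend to context positions j < boundary[i]
    # feature map: elu(x) + 1 (linear transformers), standing in for FAVOR+ features
    q, k = F.elu(q) + 1, F.elu(k) + 1

    # running prefix sums over context positions, as in causal linear attention
    kv = torch.einsum('bhmd,bhme->bhmde', k, v).cumsum(dim = 2) # (b, h, m, d, d)
    z = k.cumsum(dim = 2)                                       # (b, h, m, d)

    # each query gathers the prefix state at its own cutoff
    idx = (boundary - 1).clamp(min = 0)
    num = torch.einsum('bhnd,bhnde->bhne', q, kv[:, :, idx])
    den = torch.einsum('bhnd,bhnd->bhn', q, z[:, :, idx]).clamp(min = 1e-8)

    out = num / den.unsqueeze(-1)
    return out * (boundary > 0).view(1, 1, -1, 1) # zero out queries with an empty context window

# reproduces the mask above: query i attends to context positions j with i > j
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)
boundary = torch.arange(1024).clamp(max = 512)

out = staircase_linear_cross_attention(q, k, v, boundary) # (1, 8, 1024, 64)

This keeps linear attention's O(n + m) flavor but materializes a (d x d) prefix state per context position, so a chunked implementation and the library's actual random-feature maps would still be needed for a real PR.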

Thanks!