Open BarKetPlace opened 1 year ago
Hi, I am looking for a way to use the code in a causal cross-attention setting, with an arbitrary causal mask that I could specify.
e.g.
    import torch
    from performer_pytorch import CausalCrossAttention

    attn = CausalCrossAttention(
        dim = 512,
        heads = 8
    ).cuda()

    x = torch.randn(1, 1024, 512).cuda()
    context = torch.randn(1, 512, 512).cuda()
    mask = (torch.arange(1024).reshape(-1,1) > torch.arange(512).reshape(1,-1)).unsqueeze(0)

    attn(x, context = context, mask=mask) # (1, 1024, 512)
Is anybody working on this? I would be happy to implement it myself, but would appreciate some pointers or ideas on how to do it with the library.
Thanks!
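To pin down the semantics being requested, here is a minimal reference sketch using ordinary quadratic softmax attention in plain PyTorch. It is not Performer's linear attention and not part of performer_pytorch: the function name `masked_cross_attention` is hypothetical, there are no learned projections or multiple heads, and the mask convention (True means query i may attend to key j) is an assumption. It materializes the full score matrix, which is exactly what a linear-attention version would need to avoid.

```python
import torch

def masked_cross_attention(q, k, v, mask):
    """Reference softmax cross-attention with an arbitrary boolean mask.

    Hypothetical helper, not a performer_pytorch API.
    q: (batch, n_q, d); k and v: (batch, n_kv, d)
    mask: (batch, n_q, n_kv) boolean, True where query i may attend to key j.
    """
    scale = q.shape[-1] ** -0.5
    # Full (n_q x n_kv) score matrix: quadratic in sequence length.
    scores = torch.einsum("bid,bjd->bij", q, k) * scale
    scores = scores.masked_fill(~mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    # A query with no allowed key has an all -inf row, which softmax turns
    # into NaN; replace those weights with zeros so the output row is zero.
    weights = torch.nan_to_num(weights, nan=0.0)
    return torch.einsum("bij,bjd->bid", weights, v)

# Same mask construction as in the issue, at a smaller size and on CPU.
torch.manual_seed(0)
x = torch.randn(1, 8, 16)
context = torch.randn(1, 4, 16)
mask = (torch.arange(8).reshape(-1, 1) > torch.arange(4).reshape(1, -1)).unsqueeze(0)

out = masked_cross_attention(x, context, context, mask)  # shape (1, 8, 16)
```

Note that with the strict `>` comparison, query 0 is allowed no keys at all, so the sketch returns a zero vector for it; a real implementation would have to decide how to treat such rows.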