lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

[Bug] Error when `rotary_pos_emb` set to True in cross attention #247

Closed BakerBunker closed 3 months ago

BakerBunker commented 4 months ago
import torch
from x_transformers import Encoder, CrossAttender

enc = Encoder(dim=512, depth=6)
model = CrossAttender(
    dim=512,
    depth=6,
    rotary_pos_emb=True,  # enabling rotary embeddings here triggers the error
    attn_flash=True,
)

# single query token attending over five encoded context tokens
nodes = torch.randn(1, 1, 512)
node_masks = torch.ones(1, 1).bool()

neighbors = torch.randn(1, 5, 512)
neighbor_masks = torch.ones(1, 5).bool()

encoded_neighbors = enc(neighbors, mask=neighbor_masks)
model(
    nodes, context=encoded_neighbors, mask=node_masks, context_mask=neighbor_masks
)  # expected output shape (1, 1, 512), but this call errors out
lucidrains commented 3 months ago

hmm, are the source and target sequences in some shared coordinate space? usually you cannot use rotary embeddings in cross attention
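For context, here is a minimal sketch (not x-transformers' internals) of why rotary embeddings assume a shared coordinate space: the rotation angle depends on a token's position index, and the attention score then depends on the relative offset between query and key positions. In cross attention the query positions index nodes while the key positions index encoded_neighbors, so those offsets are only meaningful if both sequences live on the same positional axis.

# sketch of RoPE applied to a 1-token query and a 5-token context from different sequences
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, positions, dim):
    # frequencies as in the original RoPE formulation
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]  # (seq, dim/2)
    freqs = torch.cat((freqs, freqs), dim=-1)               # (seq, dim)
    return x * freqs.cos() + rotate_half(x) * freqs.sin()

dim = 64
q = torch.randn(1, dim)  # single query token, position 0 of the target sequence
k = torch.randn(5, dim)  # five context tokens, positions 0..4 of a *different* sequence

q_rot = apply_rotary(q, torch.arange(1), dim)
k_rot = apply_rotary(k, torch.arange(5), dim)
# q_rot @ k_rot.T now encodes a "relative distance" between indices of two unrelated
# sequences, which is only sensible if source and target share one coordinate space.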

BakerBunker commented 3 months ago

Thank you for the explanation, it was my mistake to use rotary embeddings in cross attention
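The fix implied by the exchange is simply to construct the CrossAttender without rotary_pos_emb. A minimal version of the corrected call, reusing the tensors from the reproduction snippet above:

model = CrossAttender(
    dim=512,
    depth=6,
    attn_flash=True,  # rotary_pos_emb dropped, since query and context have separate positional axes
)

out = model(
    nodes, context=encoded_neighbors, mask=node_masks, context_mask=neighbor_masks
)  # (1, 1, 512)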

lucidrains commented 3 months ago

@BakerBunker no problem, i should have added an assert to prevent this in the cross attention setting
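A rough sketch of the kind of guard described above, not the actual patch: reject rotary embeddings at construction time when a layer only performs cross attention. The only_cross and rotary_pos_emb names mirror the library's keyword arguments, but the standalone helper here is purely illustrative; in the library the check would presumably live in the layer constructor.

def validate_attention_config(only_cross: bool, rotary_pos_emb: bool):
    # hypothetical guard: rotary positions assume queries and keys share a coordinate
    # space, which does not hold for cross-only attention
    assert not (only_cross and rotary_pos_emb), (
        'rotary positional embeddings cannot be used with cross-only attention'
    )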