lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers

ContinuousTransformer num_memory_tokens bug #194

Closed · pfeatherstone closed this issue 11 months ago

pfeatherstone commented 11 months ago

Here is a repro:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

lm = ContinuousTransformerWrapper(
    dim_in              = 4,
    dim_out             = 256 + 3,
    max_seq_len         = 0,
    num_memory_tokens   = 20,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb  = True,
        attn_flash      = True,
        use_scalenorm   = True,
        attn_onnxable   = True,
        shift_tokens    = 1
    )
)

# batch of continuous inputs plus a per-sample padding mask
x = torch.randn(2, 1024, 4)
l = torch.randint(100, x.shape[1], size=(x.shape[0],))       # valid lengths per sample
m = torch.arange(x.shape[1]).unsqueeze(0) < l.unsqueeze(-1)  # (batch, seq) boolean mask
out = lm(x, mask=m)
```
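
For context on why the mask is the likely culprit: memory tokens are extra learned tokens prepended to the sequence inside the wrapper, so a user-supplied mask of shape (batch, seq) presumably has to be left-padded to (batch, num_memory_tokens + seq) before attention sees it. A minimal sketch of that shape bookkeeping (an assumption about where the repro breaks, not the library's actual code):

```python
import torch.nn.functional as F

# hypothetical illustration: the (batch, seq) boolean mask `m` from the repro
# needs num_memory_tokens extra True entries on the left, so the memory tokens
# are always attended to
num_memory_tokens = 20
m_padded = F.pad(m, (num_memory_tokens, 0), value=True)   # (batch, 20 + seq)
```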
lucidrains commented 11 months ago

should be fixed
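
With the fix in place, the repro above should run cleanly; a quick sanity check, assuming the wrapper drops the memory tokens again before the output projection so the sequence length is unchanged:

```python
out = lm(x, mask=m)
assert out.shape == (2, 1024, 256 + 3)   # (batch, seq, dim_out), memory tokens not returned
```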