lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

attn_num_mem_kv > 0 and attn_one_kv_head = True error #228

Closed by pfeatherstone 6 months ago

pfeatherstone commented 6 months ago

If you set both attn_num_mem_kv > 0 and attn_one_kv_head = True, you get an error.

Repro:

import torch
from x_transformers import TransformerWrapper, Decoder

lm = TransformerWrapper(
    num_tokens          = 32,
    max_seq_len         = 0,
    num_memory_tokens   = 20,
    attn_layers = Decoder(
        dim              = 512,
        depth            = 4,
        heads            = 4,
        rotary_pos_emb   = True,
        attn_flash       = True,
        attn_onnxable    = True,
        attn_num_mem_kv  = 20,
        attn_one_kv_head = True
    )
)

x = torch.randint(0, 32, (8, 1024))
logits = lm(x)
print(x.shape, x.dtype)
print(logits.shape, logits.dtype)

lucidrains commented 6 months ago

@pfeatherstone ah nice catch

corrected

pfeatherstone commented 6 months ago

The first fix is:

https://github.com/lucidrains/x-transformers/blob/029ec31957a86dbfaa4a9865f0f23403ef31e4aa/x_transformers/x_transformers.py#L789-L792

self.num_mem_kv = num_mem_kv
if num_mem_kv > 0:
    self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
    self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
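
For context, a minimal standalone sketch of why the memory key/values have to be allocated with kv_heads rather than heads when attn_one_kv_head = True. The shapes and variable names below are illustrative only, not the library's internals:

import torch

# Illustrative shapes: with attn_one_kv_head = True there is a single
# key/value head, so learned memory key/values must be allocated per
# kv head, not per query head.
batch, heads, kv_heads, seq, num_mem_kv, dim_head = 8, 4, 1, 1024, 20, 128

k          = torch.randn(batch, kv_heads, seq, dim_head)     # projected keys, single kv head
mem_k_bad  = torch.randn(heads,    num_mem_kv, dim_head)     # allocated with `heads`  -> mismatch
mem_k_good = torch.randn(kv_heads, num_mem_kv, dim_head)     # allocated with `kv_heads` -> matches

expand = lambda m: m.unsqueeze(0).expand(batch, -1, -1, -1)  # add batch dim

# torch.cat((expand(mem_k_bad), k), dim = -2)  # RuntimeError: head dims 4 vs 1
k_with_mem = torch.cat((expand(mem_k_good), k), dim = -2)
print(k_with_mem.shape)                                      # torch.Size([8, 1, 1044, 128])
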
pfeatherstone commented 6 months ago

Not so fast

lucidrains commented 6 months ago

@pfeatherstone there's another issue?

pfeatherstone commented 6 months ago

it fails at https://github.com/lucidrains/x-transformers/blob/029ec31957a86dbfaa4a9865f0f23403ef31e4aa/x_transformers/attend.py#L135
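
The exact cause at that line isn't shown here, but a purely illustrative sketch of this kind of flash-attention failure (not the library's code): with a single kv head, q carries more heads than k/v, and depending on the PyTorch version F.scaled_dot_product_attention rejects the head-count mismatch unless k/v are expanded over the head dimension first:

import torch
import torch.nn.functional as F

# Illustrative only: q has 4 heads, k/v have 1 (single kv head, memory kv appended)
batch, heads, q_len, k_len, dim_head = 8, 4, 1024, 1024 + 20, 128

q = torch.randn(batch, heads, q_len, dim_head)
k = torch.randn(batch, 1,     k_len, dim_head)
v = torch.randn(batch, 1,     k_len, dim_head)

# F.scaled_dot_product_attention(q, k, v)      # may error on the head mismatch

k, v = map(lambda t: t.expand(-1, heads, -1, -1), (k, v))
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)                               # torch.Size([8, 4, 1024, 128])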

lucidrains commented 6 months ago

weird, let me try your example, runs for me

lucidrains commented 6 months ago

oh i see, it is an issue with flash attention

lucidrains commented 6 months ago

ok, fixed for real, thanks for reporting