Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Value-dim different from query-dim/key-dim is not supported #952

Open hahnyuan opened 5 months ago

hahnyuan commented 5 months ago

Hello flash_attn Maintainers and Community,

When running the code snippet below with the flash_attn library, a runtime error is raised:

import torch
import flash_attn

bs = 1
seqlen = 128
qkdim = 128
vdim = 256  # value head dim intentionally differs from the query/key head dim
nheads = 8
q = torch.ones(bs, seqlen, nheads, qkdim, dtype=torch.float16).cuda()
k = torch.ones(bs, seqlen, nheads, qkdim, dtype=torch.float16).cuda()
v = torch.ones(bs, seqlen, nheads, vdim, dtype=torch.float16).cuda()

attn = flash_attn.flash_attn_func(q, k, v)

Result: RuntimeError: v must have shape (batch_size, seqlen_k, num_heads_k, head_size_og)

It seems that the error is due to a mismatch between the head dimension of the value tensor v and that of the query and key tensors q and k.

Is it possible to extend support to scenarios where the dimensionality of the value tensor v differs from that of the query and key tensors q and k?
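In the meantime, a possible workaround (not something the library provides, just a sketch assuming vdim >= qkdim) is to zero-pad q and k up to the value head dimension; the zero columns do not change the q @ k^T dot products, but the softmax scale then has to be passed explicitly so it still reflects the original qkdim:

import math
import torch
import torch.nn.functional as F
import flash_attn

bs, seqlen, qkdim, vdim, nheads = 1, 128, 128, 256, 8
q = torch.ones(bs, seqlen, nheads, qkdim, dtype=torch.float16).cuda()
k = torch.ones(bs, seqlen, nheads, qkdim, dtype=torch.float16).cuda()
v = torch.ones(bs, seqlen, nheads, vdim, dtype=torch.float16).cuda()

# Zero-pad the last (head) dimension of q and k from qkdim up to vdim.
q_pad = F.pad(q, (0, vdim - qkdim))
k_pad = F.pad(k, (0, vdim - qkdim))

# Pass softmax_scale explicitly; the default 1/sqrt(head_dim) would use the padded dim.
attn = flash_attn.flash_attn_func(q_pad, k_pad, v, softmax_scale=1.0 / math.sqrt(qkdim))

This pays extra memory and compute for the padded head dimension, so it is only a stopgap rather than a substitute for native support.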

tridao commented 5 months ago

No, we don't plan to support that.