In the official API of `xformers.ops.memory_efficient_attention`, the input shape should be: "Input tensors must be in format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head." But our input is [B, H, M, K], so I fixed this by adding:
```python
# xformers expects [B, M, H, K]: swap the head and sequence dims first.
q_xf = q.transpose(1, 2).contiguous()
k_xf = k.transpose(1, 2).contiguous()
v_xf = v.transpose(1, 2).contiguous()
# Output is also [B, M, H, K]; merge the heads back into the channel dim.
x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C)
```
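For context, here is a minimal self-contained sketch of the same fix. The concrete shapes (B=2, H=8, M=128, K=64), the fp16/CUDA setup, and the tensor names are hypothetical stand-ins, not taken from the original code; the memory-efficient attention kernels generally require a CUDA device.

```python
import torch
import xformers.ops

# Hypothetical shapes: batch, heads, sequence length, per-head embedding size.
B, H, M, K = 2, 8, 128, 64

# Tensors start in the [B, H, M, K] layout, as produced by a typical
# qkv projection followed by a permute.
q = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)

# Transpose to the [B, M, H, K] layout that xformers expects.
q_xf = q.transpose(1, 2).contiguous()
k_xf = k.transpose(1, 2).contiguous()
v_xf = v.transpose(1, 2).contiguous()

out = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf)  # [B, M, H, K]
x = out.reshape(B, M, H * K)  # [B, N, C] with N = M and C = H * K
```

Since the output of `memory_efficient_attention` is already contiguous in [B, M, H, K], the final `reshape(B, N, C)` merges the head and per-head dims directly without another transpose.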
Thanks for your PR.