Vchitect / Latte

Latte: Latent Diffusion Transformer for Video Generation.
Apache License 2.0

fix xformer input shape of Q,K,V in latte.py #69

Closed · tianyma closed this 2 months ago

tianyma commented 2 months ago

In the official API of xformers.ops.memory_efficient_attention, the input shape should be

Input tensors must be in format [B, M, H, K], where B is the batch size, M the sequence length, H the number of heads, and K the embedding size per head

but our input is [B, H, M, K], so I fixed this by adding:

            # [B, H, M, K] -> [B, M, H, K], the layout xformers expects
            q_xf = q.transpose(1, 2).contiguous()
            k_xf = k.transpose(1, 2).contiguous()
            v_xf = v.transpose(1, 2).contiguous()
            # output is [B, M, H, K]; flatten the heads back to [B, N, C]
            x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C)
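
For reference, a standalone sketch like the following can check the layout requirement against PyTorch's scaled_dot_product_attention, which takes [B, H, M, K] directly. The tensor sizes and the CUDA/half-precision setup are illustrative assumptions, not part of this PR:

    # Minimal shape-check sketch (not from the PR): assumes xformers and a
    # CUDA device are available; B, H, M, K are arbitrary example sizes.
    import torch
    import xformers.ops

    B, H, M, K = 2, 8, 16, 64  # batch, heads, sequence length, head dim
    q = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # PyTorch's SDPA consumes the [B, H, M, K] layout directly.
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)

    # xformers wants [B, M, H, K]: transpose in, transpose the result back.
    out = xformers.ops.memory_efficient_attention(
        q.transpose(1, 2).contiguous(),
        k.transpose(1, 2).contiguous(),
        v.transpose(1, 2).contiguous(),
    ).transpose(1, 2)

    print(torch.allclose(ref, out, atol=1e-2))  # expected: True

Without the transposes, the two calls silently treat the heads axis as the sequence axis, so results differ even though no error is raised.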
maxin-cn commented 2 months ago

Thanks for your PR.