Open ChuanhongLi opened 12 months ago
Hi @ChuanhongLi
The doc is a bit outdated here, and indeed Flash-Decoding supports BlockDiagonalCausalWithOffsetPaddedKeysMask now :)
Wow! Thanks for your quick reply. @danthe3rd I used pip to install xformers; the version is 0.0.22.post4. But in xformers/ops/fmha/flash.py, I cannot find "BlockDiagonalCausalWithOffsetPaddedKeysMask".
Did I miss it?
Thanks!
It will use the split_k backend, which is also a Flash-Decoding implementation, but in Triton:
https://github.com/facebookresearch/xformers/blob/001590cef2258f9d70f1c2edd632057ae45bcf8e/xformers/ops/fmha/triton_splitk.py#L490
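For intuition, the split-K reduction that Flash-Decoding builds on can be sketched in plain NumPy: the KV sequence is split into chunks, each chunk produces a partial attention output plus its log-sum-exp, and the partials are then merged. This is only an illustration of the algorithm, not the actual Triton kernel:

```python
# Split-K attention sketch: partial softmax per KV chunk, merged via
# log-sum-exp weights. Illustrative only, not the xformers Triton kernel.
import numpy as np

def attention(q, k, v):
    # Reference single-pass softmax attention.
    s = q @ k.T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def split_k_attention(q, k, v, num_splits=4):
    outs, lses = [], []
    for ks, vs in zip(np.array_split(k, num_splits), np.array_split(v, num_splits)):
        s = q @ ks.T
        m = s.max(axis=-1, keepdims=True)
        p = np.exp(s - m)
        lsum = p.sum(axis=-1, keepdims=True)
        outs.append((p @ vs) / lsum)        # partial (locally normalized) output
        lses.append(m + np.log(lsum))       # log-sum-exp of this chunk's scores
    lse = np.concatenate(lses, axis=-1)
    w = np.exp(lse - lse.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)   # softmax over per-chunk LSEs
    return sum(w[..., i:i + 1] * outs[i] for i in range(num_splits))

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 8))             # single decoding query
k = rng.standard_normal((32, 8))
v = rng.standard_normal((32, 8))
assert np.allclose(split_k_attention(q, k, v), attention(q, k, v))
```

Splitting along K is what lets a single decoding query keep all SMs busy even when the batch is small, since each chunk can be processed by a different block before the cheap merge step.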
@danthe3rd Thanks! I will have a try!
@danthe3rd Hi, there may be something wrong in triton_splitk.py. When I run the llama_inference example with CodeLlama-34b on a single query of 13388 tokens, xformers.ops.fmha.triton_splitk.FwOp is chosen, but I get an error at line 538. To find the problem, I added some logs. During decoding the input query length is 1, but after line 535 (q_len = seqinfo.min_seqlen), q_len becomes seqinfo.min_seqlen (here seqinfo.min_seqlen is 13388???). Then the condition at line 536 (if q_len != seqinfo.max_seqlen) is triggered, since seqinfo.max_seqlen is 1, so triton_splitk.FwOp cannot be used.
seqinfo.min_seqlen = 13388, seqinfo.max_seqlen = 1
Is something wrong here? Was this forgotten in an update, or is it something else?
When I simply skip line 535 (q_len = seqinfo.min_seqlen), it works fine.
Thanks!
cc @bottler @mpu
There might be something wrong indeed... This backend shouldn't be used for prompt encoding, and when decoding the next token you should have a single query per sequence (so min_seqlen = max_seqlen = 1).
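The rejected check can be reproduced in plain Python. The snippet below is a hypothetical reconstruction of the logic around lines 535-538 of triton_splitk.py; the SeqLenInfo class and splitk_not_supported helper are illustrative names, not the actual xformers source:

```python
# Hypothetical reconstruction of the triton_splitk.py support check.
# The class/function names are illustrative, not the real xformers code.
from dataclasses import dataclass

@dataclass
class SeqLenInfo:
    min_seqlen: int
    max_seqlen: int

def splitk_not_supported(seqinfo: SeqLenInfo) -> list:
    reasons = []
    q_len = seqinfo.min_seqlen          # "line 535": assumes min == max
    if q_len != seqinfo.max_seqlen:     # "line 536": reject ragged query lengths
        reasons.append("Variable query length is not supported")
    return reasons

# Next-token decoding: one query per sequence, min == max == 1 -> accepted.
assert splitk_not_supported(SeqLenInfo(min_seqlen=1, max_seqlen=1)) == []

# The state reported above (min_seqlen=13388, max_seqlen=1) -> rejected,
# which suggests min/max were swapped or populated from the KV side.
assert splitk_not_supported(SeqLenInfo(min_seqlen=13388, max_seqlen=1)) != []
```

This matches the observed behavior: with a consistent seqinfo the check passes, and with the reported min_seqlen = 13388 / max_seqlen = 1 the backend refuses to run.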
❓ Questions and Help
As mentioned in issue https://github.com/facebookresearch/xformers/issues/894, "memory_efficient_attention will automatically use the Flash-Decoding algorithm if it is supported (it requires bf16/f16, no bias, and A100 or newer GPU)", but the code in model.py calls memory_efficient_attention_forward with an attn_bias, and lines 15-17 of model.py read:
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)
According to your comments and the code, memory_efficient_attention will never choose the Flash-Decoding algorithm in this case?
I wonder if I have misunderstood something.
Thanks!
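The quoted requirements can be written out as a small predicate. This is only a paraphrase of the documented conditions (the function name and the string-based arguments are illustrative, not the real xformers dispatch logic), updated with the maintainer's note above that BlockDiagonalCausalWithOffsetPaddedKeysMask is now also accepted:

```python
# Illustrative eligibility predicate for Flash-Decoding dispatch, paraphrasing
# the quoted doc requirements. Not the actual xformers selection code.
def flash_decoding_eligible(dtype: str, attn_bias_type: str, sm: int) -> bool:
    # Requires half precision and an A100 (sm80) or newer GPU.
    half_precision = dtype in ("bf16", "f16")
    new_enough_gpu = sm >= 80
    # The doc said "no bias"; per the reply above, the padded block-diagonal
    # causal mask is now supported as well.
    supported_bias = attn_bias_type in (
        "none",
        "BlockDiagonalCausalWithOffsetPaddedKeysMask",
    )
    return half_precision and new_enough_gpu and supported_bias

# The model.py setup from the question: bf16, padded-keys mask, A100 -> eligible.
assert flash_decoding_eligible("bf16", "BlockDiagonalCausalWithOffsetPaddedKeysMask", 80)
# fp32, or an arbitrary dense bias, would rule it out.
assert not flash_decoding_eligible("f32", "none", 90)
assert not flash_decoding_eligible("f16", "LowerTriangularMask", 80)
```

So under this reading the attn_bias in model.py no longer disqualifies Flash-Decoding; the "no bias" wording in the doc is simply outdated, as the maintainer says above.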