facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Questions about llama_inference example #903

Open ChuanhongLi opened 12 months ago

ChuanhongLi commented 12 months ago

❓ Questions and Help

As mentioned in issue https://github.com/facebookresearch/xformers/issues/894, "memory_efficient_attention will automatically use the Flash-Decoding algorithm if it is supported (it requires bf16/f16, no bias, and A100 or newer GPU)", but the code in model.py calls memory_efficient_attention_forward with an attn_bias; see lines 15-17 of model.py:

from xformers.ops.fmha.attn_bias import ( BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias, )

According to that comment and the code, memory_efficient_attention will never choose the Flash-Decoding algorithm here?
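For reference, the decoding call pattern being discussed looks roughly like this. This is a minimal sketch: the shapes, dtypes, padding, and sequence lengths below are illustrative assumptions and are not copied from model.py.

```python
import torch

from xformers.ops import fmha
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)

B, H, D = 2, 8, 128      # batch size, attention heads, head dim (illustrative)
PADDED_KV = 4096         # per-sequence padded KV-cache length (illustrative)

# Block-diagonal layout: sequences are concatenated along the token axis,
# so the leading "batch" dimension is 1. One new query token per sequence.
q = torch.randn(1, B * 1, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, B * PADDED_KV, H, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)

attn_bias = AttnBias.from_seqlens(
    q_seqlen=[1] * B,        # one query token per sequence (decoding step)
    kv_padding=PADDED_KV,    # padded cache length
    kv_seqlen=[100, 250],    # tokens actually present in each cache
)

out = fmha.memory_efficient_attention_forward(q, k, v, attn_bias=attn_bias)
```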

I wonder if I misunderstand something.

Thanks!

danthe3rd commented 12 months ago

Hi @ChuanhongLi The doc is a bit outdated here, and indeed Flash-Decoding supports BlockDiagonalCausalWithOffsetPaddedKeysMask now :)

ChuanhongLi commented 12 months ago

> Hi @ChuanhongLi The doc is a bit outdated here, and indeed Flash-Decoding supports BlockDiagonalCausalWithOffsetPaddedKeysMask now :)

Wow! Thanks for your quick reply, @danthe3rd. I installed xformers with pip (version 0.0.22.post4), but in xformers/ops/fmha/flash.py I cannot find "BlockDiagonalCausalWithOffsetPaddedKeysMask".
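One way to check what the installed build actually supports is to inspect the forward-op classes directly. This sketch assumes the op classes expose SUPPORTED_ATTN_BIAS_TYPES and NAME, and that the triton_splitk module is present in the installed version:

```python
from xformers.ops import fmha
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask

# List which forward ops advertise support for the padded-keys mask.
for op in (fmha.flash.FwOp, fmha.triton_splitk.FwOp):
    supported = BlockDiagonalCausalWithOffsetPaddedKeysMask in op.SUPPORTED_ATTN_BIAS_TYPES
    print(f"{op.NAME}: supports BlockDiagonalCausalWithOffsetPaddedKeysMask = {supported}")
```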

Am I missing something?

Thanks!

danthe3rd commented 12 months ago

It will use the split_k backend, which is also a Flash-Decoding implementation but in Triton https://github.com/facebookresearch/xformers/blob/001590cef2258f9d70f1c2edd632057ae45bcf8e/xformers/ops/fmha/triton_splitk.py#L490
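If it helps when experimenting, the backend can also be requested explicitly rather than left to automatic dispatch. This is a sketch that reuses q/k/v/attn_bias from the earlier snippet and assumes the op= keyword of memory_efficient_attention_forward:

```python
from xformers.ops import fmha

# Ask for the Triton split-K forward op directly; xformers raises a descriptive
# error if the inputs are not supported by this backend.
out = fmha.memory_efficient_attention_forward(
    q, k, v, attn_bias=attn_bias, op=fmha.triton_splitk.FwOp
)
```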

ChuanhongLi commented 12 months ago

> It will use the split_k backend, which is also a Flash-Decoding implementation but in Triton
> https://github.com/facebookresearch/xformers/blob/001590cef2258f9d70f1c2edd632057ae45bcf8e/xformers/ops/fmha/triton_splitk.py#L490

@danthe3rd Thanks! I will have a try!

ChuanhongLi commented 12 months ago

> It will use the split_k backend, which is also a Flash-Decoding implementation but in Triton https://github.com/facebookresearch/xformers/blob/001590cef2258f9d70f1c2edd632057ae45bcf8e/xformers/ops/fmha/triton_splitk.py#L490
>
> @danthe3rd Thanks! I will have a try!

@danthe3rd Hi, there may be something wrong with triton_splitk.py. When I run the llama_inference example with CodeLlama-34b on one query of 13388 tokens, xformers.ops.fmha.triton_splitk.FwOp is chosen, but I get the error raised at line 538. To track down the problem, I added some logs. During decoding the input query length is 1, but after line 535 (q_len = seqinfo.min_seqlen), q_len becomes seqinfo.min_seqlen, which here is 13388 (?!). The check at line 536 (if q_len != seqinfo.max_seqlen) then triggers, since seqinfo.max_seqlen is 1, so triton_splitk.FwOp cannot be used.

seqinfo.min_seqlen = 13388
seqinfo.max_seqlen = 1

Is something wrong here? Was an update missed, or is it something else?

When I simply skip line 535 (q_len = seqinfo.min_seqlen), it works fine.
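For anyone trying to reproduce this without the full example, a small diagnostic sketch: the q_seqinfo/min_seqlen/max_seqlen attribute names follow the fields referenced above, and the from_seqlens values are illustrative.

```python
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)

# One decoding step: a single new query token attending to a long padded KV cache.
bias = AttnBias.from_seqlens(
    q_seqlen=[1],          # one query token for this sequence
    kv_padding=16384,      # padded cache length (illustrative)
    kv_seqlen=[13388],     # tokens actually present in the cache
)

# For a single-token decode step both values are expected to be 1;
# min_seqlen = 13388 here would mean the query seqinfo picked up a KV length.
print("q_seqinfo.min_seqlen =", bias.q_seqinfo.min_seqlen)
print("q_seqinfo.max_seqlen =", bias.q_seqinfo.max_seqlen)
```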

Thanks!

danthe3rd commented 12 months ago

cc @bottler @mpu There might be something wrong indeed... This backend shouldn't be used for prompt encoding, and for next-token decoding you should have a single query per sequence (so min_seqlen = max_seqlen = 1).