rayleizhu opened 6 months ago
BTW, I noticed that you mentioned in the blog that
> And even cooler, these kernels are actually faster than the built in alternatives (CuBLAS and FlashAttention2)!
This is unsurprising as flash attention 1 & 2 are designed for training-phase speed (with a large batch size). Flash decoding should be a stronger baseline.
@rayleizhu To make torch.compile work with 3rd-party ops, you need to register them. I'll put up an example of how to do this later.
> This is unsurprising as flash attention 1 & 2 are designed for training-phase speed (with a large batch size). Flash decoding should be a stronger baseline.
I certainly agree :), and I would expect FlashDecoding to be faster than the torch.compile generated ops. But FlashDecoding is not integrated into PyTorch yet.
mark! :)
I found the examples here.
However, I have another question: is the registration required by torch.cuda.CUDAGraph() or torch._dynamo? Do I still need this registration if I want to define the graph manually with torch.cuda.CUDAGraph() instead of capturing it with Dynamo?
I've tried the torch.library approach, and ran into some problems which I've outlined here: https://github.com/pytorch/pytorch/issues/120441
Your repro works for me with pytorch-nightly. TORCH_COMPILE_DEBUG gives me this:
def forward(self, arg0_1: "bf16[1, 2, 2, 4]", arg1_1: "bf16[1, 5, 2, 4]", arg2_1: "bf16[1, 5, 2, 4]", arg3_1: "bf16[1, 1, 2, 4]", arg4_1: "bf16[1, 1, 2, 4]", arg5_1: "i32[1]"):
    # File: /home/yifu/pytorch/torch/_dynamo/external_utils.py:25 in inner, code: return fn(*args, **kwargs)
    auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.mylib.custom_func.default, q = arg0_1, k_cache = arg1_1, v_cache = arg2_1, k = arg3_1, v = arg4_1, cache_seqlens = arg5_1); arg0_1 = arg3_1 = arg4_1 = arg5_1 = None
    getitem: "bf16[1, 2, 2, 4]" = auto_functionalized[0]
    getitem_1: "bf16[1, 5, 2, 4]" = auto_functionalized[1]
    getitem_2: "bf16[1, 5, 2, 4]" = auto_functionalized[2]; auto_functionalized = None
    copy_: "bf16[1, 5, 2, 4]" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = None
    copy__1: "bf16[1, 5, 2, 4]" = torch.ops.aten.copy_.default(arg2_1, getitem_2); arg2_1 = getitem_2 = None
    return (getitem,)
Maybe give the latest nightly a shot?
> I found the examples here.
> However, I have another question: is the registration required by torch.cuda.CUDAGraph() or torch._dynamo? Do I still need this registration if I want to define the graph manually with torch.cuda.CUDAGraph() instead of capturing it with Dynamo?
The answer is no: the registration is required by Dynamo, not by torch.cuda.CUDAGraph(). If anyone has difficulties using Dynamo, consider capturing the graph manually with CUDAGraph; see the blog post here. However, you need to make sure that the captured graph is static (be careful with data-dependent if/for statements).
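For reference, a minimal sketch of manual capture, assuming a CUDA device is present. The warmup-on-a-side-stream pattern follows the `torch.cuda.graph` documentation; `capture_step` and the SDPA workload are illustrative stand-ins, not anyone's actual kernel.

```python
import torch
import torch.nn.functional as F

def capture_step(fn, static_inputs, warmup=3):
    """Capture fn(*static_inputs) into a CUDA graph with fixed buffers."""
    # Warm up on a side stream so lazily-allocated workspaces (cuBLAS,
    # cuDNN, etc.) exist before capture begins.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup):
            fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Tensors created during capture become the graph's static outputs.
        static_output = fn(*static_inputs)
    return graph, static_output

if torch.cuda.is_available():
    q, k, v = (torch.randn(1, 2, 4, 8, device="cuda") for _ in range(3))
    graph, out = capture_step(F.scaled_dot_product_attention, (q, k, v))
    # To run on new data: copy it into the captured input buffers in place,
    # then replay; `out` is overwritten with the new result.
    q.copy_(torch.randn(1, 2, 4, 8, device="cuda"))
    graph.replay()
```

Note that replay re-runs exactly the kernels recorded at capture time on whatever data sits in the static buffers, which is why the graph must be shape- and control-flow-static.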
I'm trying to replace F.scaled_dot_product_attention with a flash decoding kernel for faster inference.
However, while the flash decoding function works well in eager mode, I cannot make it work with torch.compile(). It seems that torch.compile() does not support such third-party operators out of the box. How can I overcome this problem?
My code is like:
And the error message with the --compile option is: