
[RFC]: OpenAI Triton-only backend #5083

Open bringlein opened 2 months ago

bringlein commented 2 months ago

Motivation.

Recently, the OpenAI Triton backend for AMD hardware (PR 3643) was merged; it is so far the only flash attention backend whose kernel source code is part of vLLM itself. Among the advantages of OpenAI Triton are superior platform and performance portability. Therefore, we (@tdoublep and I) wanted to investigate whether this code could work equally well on a different platform, i.e. NVIDIA GPUs.

Our experiments show that the code contributed by AMD achieves competitive prefill performance on NVIDIA hardware (A100, L40, H100) compared to the default option (flash_attn). For smaller numbers of heads, as may occur when using tensor parallelism, it is even faster.

[Figures: prefill performance of the Triton attention kernel vs. flash_attn on A100, L40, and H100]

For these experiments, we used the code contributed by AMD but replaced the autotuner options with options better suited to the different GPUs. However, we did not change the actual Triton code.
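To illustrate what this swap amounts to, here is a minimal sketch (assuming Triton's standard `@triton.autotune` mechanism; the config values are illustrative and `attn_fwd` is a stub standing in for the AMD-contributed flash attention kernel):

```python
import triton
import triton.language as tl

# Illustrative per-GPU config list; the real values come from tuning on each
# target (A100/L40/H100 favor different tile sizes, stages, and warp counts).
_NVIDIA_CONFIGS = [
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=4, num_warps=8),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=3, num_warps=4),
]

@triton.autotune(configs=_NVIDIA_CONFIGS, key=["N_CTX", "HEAD_DIM"])
@triton.jit
def attn_fwd(Q, K, V, Out, N_CTX, HEAD_DIM,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # ... kernel body unchanged from the AMD-contributed implementation ...
    pass
```

Only the list of `triton.Config` entries changes per platform; the kernel body itself stays identical.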

Therefore, could we consider a Triton-only backend? While this does not (yet) result in a performance advantage in all cases, there are a number of additional technical motivations:

  1. Relying mainly on Triton code would reduce the dependency on hand-written CUDA code (e.g. ~3500 LoC for the different variants of the forward kernel in flash_attn vs. ~500 LoC of Triton code).
  2. It would improve the platform portability of vLLM.
  3. And consequently, help future proofing vLLM.

Proposed Change.

We propose to add a new backend that runs the Triton flash attention code on both NVIDIA and AMD platforms. We would start with the existing flash attention kernel, but we also want to discuss the option of covering other kernels. We would also contribute our additional options for the Triton autotuner, so that the results shown by the blue curves above can be reproduced.
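From a user's perspective, enabling the backend could look like the sketch below, assuming it is wired into the existing VLLM_ATTENTION_BACKEND selection mechanism; the backend name "TRITON_FLASH" is purely hypothetical and would be decided as part of the PR:

```python
import os

# Hypothetical backend name; the actual registration/name is an open question.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_FLASH"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```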

Feedback Period.

No response

CC List.

@hongxiayang @WoosukKwon @Yard1 @jpvillam-amd

Any Other Things.

No response

robertgshaw2-neuralmagic commented 2 months ago

I think this is a good idea

Question --> what would you do about the other (non-attention) custom ops?

We also do not have these parameterized in the way we do for attention, so that will likely need to be updated.

hongxiayang commented 1 month ago

@bringlein Thanks for putting together this RFC and sharing your experiment results. Agreed that this backend is a good idea for easier portability across various GPUs, and requiring far fewer lines of code (as you mentioned, ~3500 vs. ~500 LoC) while not hurting performance is the way to go.

bringlein commented 1 month ago

Thanks for your feedback!

@robertgshaw2-neuralmagic

Question --> what would you do about the other (non-attention) custom ops? We also do not have these parameterized in the way we do for attention, so that will likely need to be updated.

Yes, exactly, this is the next/bigger question in this context.
We also implemented some of the other kernels in Triton, with the goal of having a "triton-only" vLLM deployment. In general, the performance of the Triton kernels is competitive with, or even faster than, the CUDA kernels in vLLM. For example, RMSNorm can be up to 70% faster ("wR" in the figure means RMSNorm "with residual"):

[Figure: RMSNorm kernel performance, Triton vs. CUDA, including the "with residual" (wR) variant]
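For reference, here is a minimal sketch of what such a Triton RMSNorm kernel can look like (illustrative only, not the exact kernel benchmarked above; it assumes a contiguous 2D input whose rows fit in a single block, and the "with residual" variant would additionally load and add a residual tensor before normalizing):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _rms_norm_kernel(x_ptr, w_ptr, out_ptr, n_cols, eps, BLOCK_SIZE: tl.constexpr):
    # One program normalizes one row of a contiguous (n_rows, n_cols) tensor.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    rstd = 1.0 / tl.sqrt(tl.sum(x * x, axis=0) / n_cols + eps)
    w = tl.load(w_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    y = x * rstd * w
    tl.store(out_ptr + row * n_cols + cols,
             y.to(out_ptr.dtype.element_ty), mask=mask)

def rms_norm_triton(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    out = torch.empty_like(x)
    n_rows, n_cols = x.shape
    _rms_norm_kernel[(n_rows,)](x, weight, out, n_cols, eps,
                                BLOCK_SIZE=triton.next_power_of_2(n_cols))
    return out
```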

Right now, for developing and debugging, my code selects between CUDA and Triton kernels based on environment variables... but I think this should be done in a better way ;).
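Roughly along these lines (a sketch only; the environment variable name is made up, and the two implementations would be the existing CUDA op and a Triton kernel such as the RMSNorm sketch above):

```python
import os

def select_impl(triton_impl, cuda_impl, env_var="VLLM_USE_TRITON_KERNELS"):
    # Crude development-time switch; env_var is a made-up name for this sketch.
    # A proper design would parameterize op selection the way attention
    # backends are selected today.
    return triton_impl if os.environ.get(env_var, "0") == "1" else cuda_impl

# e.g. pick an RMSNorm implementation once at import time:
# rms_norm = select_impl(rms_norm_triton, rms_norm_cuda)
```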

But maybe this could be an incremental process: first, start with the "triton-only" backend (where paged attention is an open question for us), since this abstraction is already in place, and subsequently (or in parallel) think about an architecture/abstraction for a parameterized selection of the other kernels?