[RFC]: Refactor FP8 kv-cache #4532

Closed comaniac closed 1 month ago

comaniac commented 4 months ago

Motivation.

Support float8_e4m3 for NVIDIA GPUs: The current FP8 kv-cache supports e5m2 on NVIDIA GPUs, and e4m3 on AMD GPUs. While e5m2 seems to be an ideal format for kv-cache storage due to better performance (i.e., e5m2 has the same number of exponent bits as fp16 so its scaling overhead could be smaller than e4m3), model checkpoints with FP8 weights mostly adopt e4m3 format. As a result, in order to support FP8 models in vLLM, we need float8_e4m3 kv-cache for both vendors.
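
For intuition, here is a small standalone sketch (not from the RFC; it assumes CUDA 11.8+ and cuda_fp8.h, and compiles with nvcc) showing why e4m3 storage typically needs a scaling factor while e5m2 can get away with scale 1:

#include <cuda_fp8.h>
#include <cstdio>

// e5m2 reuses fp16's 5 exponent bits (max ~57344), while e4m3 trades range for
// precision (max 448). A value that fits in fp16 can overflow e4m3 unless it is
// scaled down first.
int main() {
  const float x = 1000.0f;      // representable in fp16/e5m2, out of range for e4m3
  const float kv_scale = 4.0f;  // hypothetical per-tensor scaling factor

  __nv_fp8_storage_t e5m2 = __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E5M2);
  __nv_fp8_storage_t e4m3_unscaled =
      __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3);             // saturates to 448
  __nv_fp8_storage_t e4m3_scaled =
      __nv_cvt_float_to_fp8(x / kv_scale, __NV_SATFINITE, __NV_E4M3);  // now in range

  printf("raw bytes: e5m2=0x%02x e4m3(unscaled)=0x%02x e4m3(scaled)=0x%02x\n",
         (unsigned)e5m2, (unsigned)e4m3_unscaled, (unsigned)e4m3_scaled);
  return 0;
}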

Refactor FP8 kv-cache: In the current implementation, we use macros to dispatch FP8 tensor quantization logic in vLLM custom kernels. For example, in the current cache_kernel.cu, we use the following code to quantize tensors for CUDA or ROCm:

#if defined(ENABLE_FP8_E5M2) // Only for CUDA
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3) // Only for ROCm
      key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
      value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
      assert(false);
#endif

This creates the following issues:

  1. We cannot enable e4m3 and e5m2 in the same build.
  2. The function arguments are fixed (e.g., the e5m2 path cannot take a scaling factor).
  3. Any vendor-specific change has to touch cache_kernel.cu (and likewise attention_kernel.cu).
  4. It is hard to extend in the future to support more kv-cache data types (e.g., INT8).

Proposed Change.

User Interface

  1. The default data type fp8 aliases to float8_e4m3 for all vendors.
  2. fp8_e4m3 and fp8_e5m2 are also available to users. The vendor backend throws an error if a particular format is not supported in the current vLLM build.
  3. When running an FP16 model with FP8 kv-cache, we optionally accept per-tensor kv-cache scaling factors in JSON format. If not provided, we always use a scaling factor of 1, which is compatible with the current e5m2 strategy.
  4. When running an FP8 model, we load the kv-cache scaling factors from the model checkpoint.

Backend Interface

  1. All vendor specific FP8 KV-cache related kernels are in csrc/quantization/fp8/<vendor>/quant_utils.cuh.
  2. The only vendor-specific logic in cache_kernel.cu and attention_kernel.cu is the include shown below. Since we currently only have two GPU vendors, I think this is still acceptable.
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif
  3. We define the following interfaces for each vendor to implement.
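
A rough sketch of what each vendor's quant_utils.cuh could declare (names are taken from the code snippets later in this thread, e.g. fp8::scaled_convert; treat the exact signatures as illustrative rather than final):

#pragma once
#include <cstdint>

namespace vllm {

// Enum passed down from the Python layer; kAuto means "no FP8 quantization".
enum class Fp8KVCacheDataType { kAuto, kFp8E4m3, kFp8E5m2 };

namespace fp8 {

// Unscaled conversion between Tin and Tout for a given FP8 format, e.g.
// Tin=uint16_t (fp16 bits), Tout=uint8_t (fp8 bits), kv_dt=kFp8E4m3.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x);

// Scaled conversion: quantize/dequantize with a per-tensor kv_scale.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale);

}  // namespace fp8
}  // namespace vllm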

Roadmap

Feedback Period.

Feedback can be provided directly on the PR. Based on the comments here, I can update the RFC to elaborate.

CC List.

@robertgshaw2-neuralmagic @tlrmchlsmth @pcmoritz @zhuohan123 @WoosukKwon @HaiShaw @zhaoyang-star

Any Other Things.

No response

pcmoritz commented 4 months ago

Thanks a lot for putting together this RFC! This sounds like a solid plan to me.

Some more detailed comments (quoted in the replies below):

comaniac commented 4 months ago

> Why not store the KV scaling factors in the safetensors / model checkpoints instead of the JSON format? The scaling factors for the KV store are very similar to the scaling factors for activations, so it is very natural to handle them the same and the safetensors / model checkpoints are the most natural places to store them.

I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but do want an FP8 kv-cache (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from the checkpoint directly.

> Have you considered not having a vllm::Fp8KVCacheDataType::kAuto datatype and instead resolving the kv cache dtype to the right dtype on the Python layer? We should have all the needed information there, and the C++ kernels shouldn't need to deal with this complexity and should just be able to use a concrete type, right? (unless I'm missing something)

I think I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during the kv-cache write (cache_kernel) and read (attention_kernel), so the quantization logic has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to these kernels. Additionally, since we now use uint8_t to store fp8 values, we cannot tell whether a function like uint8_t convert(uint16_t& a) { ... } is for E4M3 or E5M2, which makes function overloading problematic. If we could use c10::Float8_e4m3fn as the kv-cache data type directly, we could just implement c10::Float8_e4m3fn convert(scalar_t& a) and c10::Float8_e5m2 convert(scalar_t& a) and let C++ find the right one.
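
A minimal sketch of that last point (illustrative only, using the c10 FP8 headers rather than vLLM's actual code): with uint8_t storage both FP8 flavors collapse to the same signature, whereas distinct output types let the caller select the format via a template argument:

#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>

// With uint8_t as the cache type, both formats look identical to the compiler:
//   uint8_t convert(const c10::Half& a);   // e4m3 or e5m2? Impossible to tell.
// With real FP8 types, the caller picks the format through the output type:
template <typename Tout, typename Tin>
Tout convert(const Tin& a);

template <>
c10::Float8_e4m3fn convert<c10::Float8_e4m3fn, c10::Half>(const c10::Half& a) {
  return c10::Float8_e4m3fn(static_cast<float>(a));
}

template <>
c10::Float8_e5m2 convert<c10::Float8_e5m2, c10::Half>(const c10::Half& a) {
  return c10::Float8_e5m2(static_cast<float>(a));
}

// Usage: auto q = convert<c10::Float8_e4m3fn>(h);  // h is a c10::Half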

robertgshaw2-neuralmagic commented 4 months ago

> I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but do want an FP8 kv-cache (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from the checkpoint directly.

But we could have an fp16 model with scales for the kv cache saved in the safetensors files:

model.layers.0.attn.k_proj.weight
model.layers.0.attn.k_cache.act_scale

So we would not need to support the JSON case.

pcmoritz commented 4 months ago

I'm +1 on supporting activation scales in the FP16 checkpoint and not in JSON. This way fewer configurations need to be supported and everything is uniform :)

pcmoritz commented 4 months ago

> I think I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during the kv-cache write (cache_kernel) and read (attention_kernel), so the quantization logic has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to these kernels. Additionally, since we now use uint8_t to store fp8 values, we cannot tell whether a function like uint8_t convert(uint16_t& a) { ... } is for E4M3 or E5M2, which makes function overloading problematic. If we could use c10::Float8_e4m3fn as the kv-cache data type directly, we could just implement c10::Float8_e4m3fn convert(scalar_t& a) and c10::Float8_e5m2 convert(scalar_t& a) and let C++ find the right one.

Sounds good! I get why the data type needs to be passed through; what I don't really get is why "auto" needs to be handled in C++ -- it seems to me it could be mapped to a concrete type in Python, and then C++ would only need to handle the concrete types. But I might be wrong, feel free to do what feels most natural while implementing this :)

comaniac commented 4 months ago

Thanks for the valuable feedback! Regarding handling "auto" in C++: the kv-cache write path ends up looking like this, where kAuto still needs its own branch so that no FP8 conversion is compiled in when FP8 is disabled:

if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
  // Make sure we don't call fp8::scaled_convert when FP8 is disabled.
  key_cache[tgt_key_idx] = tgt_key;
  value_cache[tgt_value_idx] = tgt_value;
} else {
  key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
  value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}

Then we could have:

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
  switch (kv_dt) {
#ifdef ENABLE_FP8
  case Fp8KVCacheDataType::kFp8E4m3:
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  case Fp8KVCacheDataType::kFp8E5m2:
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
#endif
  default: // Only this branch left when FP8 is disabled, so always throw an error.
    assert(false);
  }
}

pcmoritz commented 4 months ago

Ah I see, in that case, kAuto is a good name since it is the same as "auto" in python. I didn't realize it required a special code path :)

HaiShaw commented 4 months ago

Regarding "We cannot enable e4m3 and e5m2 in the same build": if we want a build with both formats supported on the same newer hardware, we most likely won't need both to function simultaneously; that adds complexity with no use case, since only e4m3 would be used in forward/inference computation. Conversely, e5m2 is only practically feasible on older hardware, as a storage type with acceptable accuracy loss from mantissa rounding; enabling e4m3 on older hardware isn't beneficial, either for the cost of the cast or for computation (no hardware support). Finally, a build that is generic across generations of GPUs seems unnecessary.

HaiShaw commented 4 months ago

Regarding "When running FP8 model, we load kv-cache scaling factor from the model checkpoint": we should have a serialized checkpoint that defines the various scaling factors, both the stationary ones (for weights, at whatever granularity) and the updatable ones (activations, KV caches); for the latter we also need to define the update process, with the quantizer flow included.

HaiShaw commented 4 months ago

> A wrapper function convert that converts Tin to Tout with a particular FP8 format. For example, when writing values to kv-cache, Tin=uint16_t, Tout=uint8_t, kv_dt=kFp8E4m3

Over time it makes more sense to move away from uint8_t and use the torch FP8 types; then kv_dt would be unnecessary.

jon-chuang commented 1 month ago

@comaniac can this be closed, as of https://github.com/vllm-project/vllm/pull/4893?