Maybe candle needs to directly call `hgemm_strided_batched` (https://github.com/coreylowman/cudarc/blob/64ddbc77a9a84ea4b8bb6b918b34c81beed71b24/src/cublas/result.rs#L222) here: https://github.com/huggingface/candle/blob/main/candle-core/src/cuda_backend/mod.rs#L1654 ?
We're actually using this function, which calls the generic gemm variant with `sys::cublasComputeType_t::CUBLAS_COMPUTE_32F` a few lines below. This means it's an f16 kernel (as per the actual kernel name), but the accumulation is done in f32, as detailed here. Using f16 accumulation would indeed be faster but less precise, so it's a bit unclear to me what the impact would end up being and whether it would be good enough for most use cases.
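To get a feel for the tradeoff, here is a small CPU-side sketch (not candle code) that emulates a narrow accumulator by rounding the running sum to an f16-style 11-bit significand. The `round_sig_bits` helper and the test values are illustrative only; it ignores exponent range and subnormals:

```rust
/// Round `x` to `bits` significant bits, emulating a reduced-precision
/// significand (ignores exponent range and subnormals).
fn round_sig_bits(x: f64, bits: i32) -> f64 {
    if x == 0.0 {
        return 0.0;
    }
    let e = x.abs().log2().floor() as i32;
    let s = 2f64.powi(bits - 1 - e);
    (x * s).round() / s
}

/// Dot product of `n` copies of (x, y). Inputs are rounded to f16-like
/// 11-bit significands either way; `accum_bits` optionally narrows the
/// accumulator as well, emulating CUBLAS_COMPUTE_16F vs CUBLAS_COMPUTE_32F.
fn dot(n: usize, x: f64, y: f64, accum_bits: Option<i32>) -> f64 {
    let (xh, yh) = (round_sig_bits(x, 11), round_sig_bits(y, 11));
    let mut acc = 0.0;
    for _ in 0..n {
        acc += xh * yh;
        if let Some(b) = accum_bits {
            acc = round_sig_bits(acc, b); // narrow the running sum
        }
    }
    acc
}

fn main() {
    let wide = dot(4096, 0.01, 1.0, None); // f32-style accumulation
    let narrow = dot(4096, 0.01, 1.0, Some(11)); // f16-style accumulation
    println!("wide = {wide:.5}, narrow = {narrow:.5}");
}
```

Once the running sum reaches 32, the f16 ulp (0.03125) exceeds the ~0.01 addend and further additions round away entirely: the wide accumulator reaches ~41 while the narrow one stalls at 32. Whether this matters in practice depends on the magnitudes involved in the actual model.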
Ah, I see, thanks.
We're currently trying to match llama.cpp speed using quantized models, so the loss of precision shouldn't matter for us.
But I can see how it matters for a regular F16 model...
Perhaps I could add this to my fork so we can try it out, and then we can merge it if we find an elegant solution?
@EricLBuehler I'd be glad to benchmark it, profile it etc if you implement it
I forked candle locally and hacked in a call to the following function at https://github.com/huggingface/candle/blob/main/candle-core/src/cuda_backend/mod.rs#L1654:
```rust
unsafe fn gemm_strided_batched<
    A: cudarc::driver::DevicePtr<half::f16>,
    B: cudarc::driver::DevicePtr<half::f16>,
    C: cudarc::driver::DevicePtrMut<half::f16>,
>(
    handle: sys::cublasHandle_t,
    cfg: StridedBatchedConfig<half::f16>,
    a: &A,
    b: &B,
    c: &mut C,
) -> std::result::Result<(), CublasError> {
    let alpha = cfg.gemm.alpha;
    let beta = cfg.gemm.beta;
    result::gemm_strided_batched_ex(
        handle,
        cfg.gemm.transa,
        cfg.gemm.transb,
        cfg.gemm.m,
        cfg.gemm.n,
        cfg.gemm.k,
        (&alpha) as *const half::f16 as *const _,
        *a.device_ptr() as *const _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.lda,
        cfg.stride_a,
        *b.device_ptr() as *const _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.ldb,
        cfg.stride_b,
        (&beta) as *const half::f16 as *const _,
        *c.device_ptr_mut() as *mut _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.ldc,
        cfg.stride_c,
        cfg.batch_size,
        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
    )
}
```
This matches the llama.cpp configuration, and the same kernels end up being used (the llama.cpp details are at the bottom). I perceived no difference in output quality. Using these settings improved throughput by 15%, bringing mistral.rs to ~1150 t/s.
Great that it works well with the reduced precision. I've looked a bit at the PyTorch codebase and it seems that they use f32 accumulation by default. PyTorch provides an option to disable "reduced precision" (which is turned on by default), but this only impacts the truncation setting in SetMathMode. See this issue: https://github.com/pytorch/pytorch/issues/123157 .
So to get around this, I've pushed #2141, which provides a toggle to flip between reduced-precision accumulation and f32 accumulation - the latter remains the default. It's a global flag, so not ideal, but it at least provides a way to test reduced-precision accumulation. The quantized example has been adapted to use it and indeed benefits from the speedup when using the f16 matmul for prompt processing. Would that work for your use case?
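For reference, a process-wide toggle like the one described can be sketched with an `AtomicBool`. The names below are placeholders, not necessarily the ones used in #2141:

```rust
use std::sync::atomic::{AtomicBool, Ordering};

/// Whether f16 GEMMs may use reduced-precision (f16) accumulation.
/// Defaults to false, i.e. keep the f32 accumulation path.
static GEMM_REDUCED_PRECISION_F16: AtomicBool = AtomicBool::new(false);

pub fn gemm_reduced_precision_f16() -> bool {
    GEMM_REDUCED_PRECISION_F16.load(Ordering::Relaxed)
}

pub fn set_gemm_reduced_precision_f16(enabled: bool) {
    GEMM_REDUCED_PRECISION_F16.store(enabled, Ordering::Relaxed);
}

fn main() {
    // The GEMM dispatch site would pick CUBLAS_COMPUTE_16F when the flag
    // is set and CUBLAS_COMPUTE_32F otherwise.
    assert!(!gemm_reduced_precision_f16());
    set_gemm_reduced_precision_f16(true);
    assert!(gemm_reduced_precision_f16());
}
```

A `Relaxed` ordering is enough here since the flag carries no data dependencies; the cost is a single atomic load per matmul dispatch.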
When it comes to changing the default, it might be better to wait a bit for what happens on the PyTorch side. If models are trained with f32 accumulation, it's a bit unclear to me what the impact will be if one runs inference with a less precise accumulation.
I'm also not sure about making it the default.
The approach of https://github.com/huggingface/candle/pull/2141 fits our use case.
Thanks a lot!
This relates to https://github.com/huggingface/candle/issues/2136
Related to improving mistral.rs prompt processing speed https://github.com/EricLBuehler/mistral.rs/issues/153
Why does candle use the `turing_fp16_s1688gemm_fp16_256x128_ldg8_f2f_tn` kernel for F16 matmuls? Llama.cpp uses `turing_h1688gemm_256x128_ldg8_tn` for the same tensor. If I understand the docs correctly (https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemm), h-gemm stands for half (F16) gemm, whereas s-gemm stands for standard F32 gemm.
So, is it possible that candle is not using the best kernel for some reason?
Is it possible that the candle version is doing the matmuls in F32, as the name would suggest, and is thus slower than the other kernel?
Our benchmarks are:

- Llama.cpp: ~1500 t/s
- mistral.rs: ~1000 t/s
And the major contributors are the kernels mentioned above. Notice that the proportion of time spent in each kernel closely matches our observed slowdown. More info here: https://github.com/EricLBuehler/mistral.rs/issues/153#issuecomment-2081632685