huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Candle won't use half-gemm from cublas when doing fp16 matmul #2139

Closed: lucasavila00 closed this issue 2 weeks ago

lucasavila00 commented 2 weeks ago

This relates to https://github.com/huggingface/candle/issues/2136

Related to improving mistral.rs prompt processing speed https://github.com/EricLBuehler/mistral.rs/issues/153

Why does candle use the turing_fp16_s1688gemm_fp16_256x128_ldg8_f2f_tn kernel for F16 matmuls?

llama.cpp uses turing_h1688gemm_256x128_ldg8_tn for the same tensors.

[profiler screenshots]

If I understand the docs correctly (https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-gemm), h-gemm stands for half-precision gemm, whereas s-gemm stands for single-precision (F32) gemm.

So is it possible that candle is, for some reason, not using the best kernel?

Could it be that candle's kernel does the matmuls in F32, as its name suggests, and is therefore slower than the other kernel?

Our benchmarks are:

llama.cpp: ~1500 t/s
mistral.rs: 1000 t/s

The main contributors are the two kernels mentioned above. Note that the ratio of time spent in each kernel closely matches the slowdown we observe. More info here: https://github.com/EricLBuehler/mistral.rs/issues/153#issuecomment-2081632685

lucasavila00 commented 2 weeks ago

Should candle call hgemm_strided_batched (https://github.com/coreylowman/cudarc/blob/64ddbc77a9a84ea4b8bb6b918b34c81beed71b24/src/cublas/result.rs#L222) directly at this point (https://github.com/huggingface/candle/blob/main/candle-core/src/cuda_backend/mod.rs#L1654)?
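For context, a strided batched GEMM (the family hgemm_strided_batched and gemm_strided_batched_ex belong to) runs batch_size independent GEMMs, where batch i reads A, B and C starting at offsets i * stride_a, i * stride_b and i * stride_c, using column-major layout as cuBLAS does. The CPU reference below is a minimal sketch of those semantics, assuming f32 data and no transposes; the function name gemm_strided_batched_ref is hypothetical, and this is not how candle or cudarc implement the call.

```rust
/// Hypothetical CPU reference for a strided batched GEMM.
/// Column-major like cuBLAS, no transposes: for each batch i,
/// C_i = alpha * A_i * B_i + beta * C_i, where A_i starts at a[i * stride_a].
#[allow(clippy::too_many_arguments)]
fn gemm_strided_batched_ref(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: &[f32],
    lda: usize,
    stride_a: usize,
    b: &[f32],
    ldb: usize,
    stride_b: usize,
    beta: f32,
    c: &mut [f32],
    ldc: usize,
    stride_c: usize,
    batch_size: usize,
) {
    for batch in 0..batch_size {
        // Base offsets of this batch's A, B and C matrices.
        let (a0, b0, c0) = (batch * stride_a, batch * stride_b, batch * stride_c);
        for col in 0..n {
            for row in 0..m {
                let mut acc = 0.0f32;
                for p in 0..k {
                    // Column-major: element (r, c) lives at r + c * leading_dim.
                    acc += a[a0 + row + p * lda] * b[b0 + p + col * ldb];
                }
                let idx = c0 + row + col * ldc;
                c[idx] = alpha * acc + beta * c[idx];
            }
        }
    }
}
```

With two 2x2 batches stored back to back (stride 4), the first multiplying by the identity and the second by 2 * identity, each batch is computed independently from its own slice of the buffers.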

LaurentMazare commented 2 weeks ago

We're actually using this function, which (a few lines below) calls the generic gemm variant with sys::cublasComputeType_t::CUBLAS_COMPUTE_32F. So it is an f16 kernel (as the actual kernel name indicates), but the accumulation is done in f32. Using f16 accumulation would indeed be faster, but with lower precision, so it's a bit unclear to me what the impact would be and whether it would be good enough for most use cases.
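To illustrate why the accumulation precision matters, here is a small CPU simulation in plain Rust (no half crate, no GPU). It emulates f16 precision by rounding an f32 to 10 mantissa bits with round-to-nearest-even, and compares a dot product accumulated in f32 with one whose running sum is rounded to f16 precision after every step. This is a simplified model assumed for illustration only: tensor cores do not necessarily round after each individual add, and the helper names are hypothetical.

```rust
/// Rounds an f32 to the nearest value representable with f16 precision
/// (10 explicit mantissa bits), using round-to-nearest-even.
/// Assumption: the input is finite and within the f16 normal range
/// (roughly 6.1e-5 to 65504 in magnitude) or zero; subnormals, overflow
/// and NaN are not modeled.
fn round_to_f16_precision(x: f32) -> f32 {
    const DROPPED_BITS: u32 = 23 - 10; // f32 has 23 mantissa bits, f16 has 10
    let bits = x.to_bits();
    let lsb = (bits >> DROPPED_BITS) & 1; // last kept bit, used for ties-to-even
    let bias = (1u32 << (DROPPED_BITS - 1)) - 1 + lsb;
    let mask = !((1u32 << DROPPED_BITS) - 1);
    f32::from_bits(bits.wrapping_add(bias) & mask)
}

/// Dot product accumulated in f32 (analogous to CUBLAS_COMPUTE_32F).
fn dot_f32_acc(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Dot product whose running sum is rounded to f16 precision after every
/// multiply-add (a simplified model of CUBLAS_COMPUTE_16F).
fn dot_f16_acc(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        round_to_f16_precision(acc + round_to_f16_precision(x * y))
    })
}
```

Summing 4096 products of 1.0 gives 4096 with f32 accumulation but only 2048 with f16 accumulation: at 2048 the spacing between adjacent f16 values is 2, so 2048 + 1 rounds back to 2048 and the sum stops growing. Under this model the error tends to grow with the reduction length k.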

lucasavila00 commented 2 weeks ago

Ah, I see, thanks.

We're currently trying to match llama.cpp speed using quantized models, so the loss of precision shouldn't matter for us.

But I can see how it matters for a regular F16 model...

EricLBuehler commented 2 weeks ago

Perhaps I could add this to my fork so we can try it out, and then we can merge it if we find an elegant solution?

lucasavila00 commented 2 weeks ago

@EricLBuehler I'd be glad to benchmark and profile it if you implement it.

lucasavila00 commented 2 weeks ago

I forked candle locally and hacked in a call to the following function at https://github.com/huggingface/candle/blob/main/candle-core/src/cuda_backend/mod.rs#L1654:

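```rust
// Hacked-in replacement for the f16 matmul path in candle's CUDA backend.
// Module paths are those of the cudarc version candle used at the time and
// may differ in other releases.
use cudarc::cublas::{result, result::CublasError, sys, StridedBatchedConfig};

/// Strided batched f16 GEMM with f16 accumulation (CUBLAS_COMPUTE_16F),
/// the configuration that matches llama.cpp.
unsafe fn gemm_strided_batched<
    A: cudarc::driver::DevicePtr<half::f16>,
    B: cudarc::driver::DevicePtr<half::f16>,
    C: cudarc::driver::DevicePtrMut<half::f16>,
>(
    handle: sys::cublasHandle_t,
    cfg: StridedBatchedConfig<half::f16>,
    a: &A,
    b: &B,
    c: &mut C,
) -> std::result::Result<(), CublasError> {
    // With CUBLAS_COMPUTE_16F, cuBLAS expects alpha and beta as f16 scalars.
    let alpha = cfg.gemm.alpha;
    let beta = cfg.gemm.beta;
    result::gemm_strided_batched_ex(
        handle,
        cfg.gemm.transa,
        cfg.gemm.transb,
        cfg.gemm.m,
        cfg.gemm.n,
        cfg.gemm.k,
        (&alpha) as *const half::f16 as *const _,
        // A: f16 input matrices
        *a.device_ptr() as *const _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.lda,
        cfg.stride_a,
        // B: f16 input matrices
        *b.device_ptr() as *const _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.ldb,
        cfg.stride_b,
        (&beta) as *const half::f16 as *const _,
        // C: f16 output matrices
        *c.device_ptr_mut() as *mut _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.ldc,
        cfg.stride_c,
        cfg.batch_size,
        // f16 accumulation instead of candle's default CUBLAS_COMPUTE_32F
        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
        sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
    )
}
```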

This matches the llama.cpp configuration, and the same kernels are now used. In the profile below, llama.cpp is at the bottom.

[profiler screenshot]

I noticed no difference in output quality.

With these settings, throughput improved by about 15%, bringing mistral.rs to ~1150 t/s.
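Putting the reported numbers together (1000 t/s before, ~1150 t/s after, ~1500 t/s for llama.cpp): the change is a 15% gain, and llama.cpp is still roughly 1.3x faster, down from 1.5x. The two trivial helpers below only make that arithmetic explicit; their names are illustrative.

```rust
/// Relative improvement of `new_tps` over `old_tps` (0.15 means +15%).
fn relative_speedup(old_tps: f64, new_tps: f64) -> f64 {
    new_tps / old_tps - 1.0
}

/// How many times faster `reference_tps` still is than `ours_tps`.
fn remaining_gap(reference_tps: f64, ours_tps: f64) -> f64 {
    reference_tps / ours_tps
}
```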

LaurentMazare commented 2 weeks ago

Great that it works well with the reduced precision. I've looked a bit at the PyTorch codebase, and it seems that they use f32 accumulation by default. PyTorch provides an option to disable "reduced precision" (which is turned on by default), but this only affects the truncation setting in SetMathMode. See this issue: https://github.com/pytorch/pytorch/issues/123157 .

To get around this, I've pushed #2141, which provides a toggle to switch between reduced-precision accumulation and f32 accumulation (f32 remains the default). It's a global flag, so not ideal, but it at least provides a way to test reduced-precision accumulation. The quantized example has been adapted to use it and does benefit from the speedup when using the f16 matmul for prompt processing. Would that work for your use case?
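A process-wide toggle like the one described can be as simple as an atomic boolean consulted when the f16 gemm call is built. The sketch below is hypothetical: the names ComputeType, set_reduced_precision_f16 and f16_gemm_compute_type are illustrative and are not the actual code in #2141, and it models only the choice of compute type, not the cuBLAS call itself.

```rust
use std::sync::atomic::{AtomicBool, Ordering};

/// The two cuBLAS compute types discussed in this thread
/// (illustrative enum, not cudarc's sys::cublasComputeType_t).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComputeType {
    Compute32F,
    Compute16F,
}

/// Process-wide flag; `false` keeps f32 accumulation as the default.
static REDUCED_PRECISION_F16: AtomicBool = AtomicBool::new(false);

fn set_reduced_precision_f16(enabled: bool) {
    REDUCED_PRECISION_F16.store(enabled, Ordering::Relaxed);
}

fn reduced_precision_f16() -> bool {
    REDUCED_PRECISION_F16.load(Ordering::Relaxed)
}

/// Compute type an f16 matmul would request under the current setting.
fn f16_gemm_compute_type() -> ComputeType {
    if reduced_precision_f16() {
        ComputeType::Compute16F
    } else {
        ComputeType::Compute32F
    }
}
```

One detail a real implementation has to respect: the scalar type of alpha and beta follows the compute type (f16 for CUBLAS_COMPUTE_16F, f32 for CUBLAS_COMPUTE_32F), so flipping the flag changes more than a single enum argument. Being global, the flag also affects every f16 matmul in the process, which is the drawback mentioned above.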

As for changing the default, it might be better to wait and see what happens on the PyTorch side. If models are trained with f32 accumulation, it's unclear to me what the impact of running inference with less precise accumulation would be.

lucasavila00 commented 2 weeks ago

I'm also not sure about making it the default.

The approach of https://github.com/huggingface/candle/pull/2141 fits our use case.

Thanks a lot!