NVIDIA / CUDALibrarySamples

CUDA Library Samples

FP8 LtFp8Matmul Error #180

Closed: jt-zhang closed this issue 4 months ago

jt-zhang commented 4 months ago

After compiling sample_cublasLt_LtFp8Matmul successfully, I ran ./sample_cublasLt_LtFp8Matmul and got:

cuBLAS API failed with status 7
terminate called after throwing an instance of 'std::logic_error'
  what():  cuBLAS API failed
Aborted (core dumped)

I did not change the source code, except for adding /usr/local/cuda-12.4/include/ to the include paths in CMakeLists.txt so that compilation would succeed.

Could anyone help? Thanks for any tips.

jt-zhang commented 4 months ago

Environment: H100, GCC 11, CUDA 12.4

Update: when I compile and run ./sample_cublasLt_LtFp8Matmul on an RTX 4090, I get:

cuBLAS API failed with status 15
terminate called after throwing an instance of 'std::logic_error'
  what():  cuBLAS API failed
Aborted (core dumped)

rsdubtso commented 4 months ago

Hi @jt-zhang,

  1. Can you please run CUBLASLT_LOG_MASK=63 ./sample_cublasLt_LtFp8Matmul and post the output here?

  2. What is the output from ldd ./sample_cublasLt_LtFp8Matmul?

  3. Can you please clarify whether you are using an H100 or an RTX 4090? In my testing, I cannot reproduce the failure on either.

oscarbg commented 4 months ago

Curious: what does CUBLASLT_LOG_MASK=63 do?

rsdubtso commented 4 months ago

> Curious: what does CUBLASLT_LOG_MASK=63 do?

It is a bitmask of logging flags; see the cuBLASLt logging section of the cuBLAS documentation for details.

jt-zhang commented 4 months ago

@rsdubtso Thank you for your reply! First, let me clarify the RTX 4090 error. Environment: RTX 4090, CUDA 12.1, GCC 11.4.

Output of CUBLASLT_LOG_MASK=63 ./sample_cublasLt_LtFp8Matmul:

Kernel = cuBLASLt_Fp8Matmul
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtCreate] lightHandle=0X7FFDFE56A8F8
[2024-04-15 13:58:01][cublasLt][115769][Info][cublasLtCreate] cuBLASLt v12.1 device #0: multiProcessorCount=128 maxSharedMemPerBlock=49152 maxSharedMemoryPerBlockOptin=101376
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulDescCreate] matmulDesc=0X7FFDFE56A5D0 computeType=COMPUTE_32F scaleType=0
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X562D8D6629F0 attr=MATMUL_DESC_TRANSA buf=0X7FFDFE56A5C0 sizeInBytes=4
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X562D8D6629F0 attr=MATMUL_DESC_TRANSB buf=0X7FFDFE56A5C4 sizeInBytes=4
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X562D8D6629F0 attr=MATMUL_DESC_A_SCALE_POINTER buf=0X7FFDFE56A598 sizeInBytes=8
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X562D8D6629F0 attr=MATMUL_DESC_B_SCALE_POINTER buf=0X7FFDFE56A588 sizeInBytes=8
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X562D8D6629F0 attr=MATMUL_DESC_C_SCALE_POINTER buf=0X7FFDFE56A578 sizeInBytes=8
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X562D8D6629F0 attr=MATMUL_DESC_D_SCALE_POINTER buf=0X7FFDFE56A568 sizeInBytes=8
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X562D8D6629F0 attr=MATMUL_DESC_AMAX_D_POINTER buf=0X7FFDFE56A560 sizeInBytes=8
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFDFE56A5D8 type=R_8F_E4M3 rows=128 cols=128 ld=128
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFDFE56A5E0 type=R_8F_E4M3 rows=128 cols=128 ld=128
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFDFE56A5E8 type=R_16BF rows=128 cols=128 ld=128
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFDFE56A5F0 type=R_8F_E4M3 rows=128 cols=128 ld=128
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulPreferenceCreate] matmulPref=0X7FFDFE56A5F8
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulPreferenceSetAttribute] pref=0X562D8DD8AD90 attr=MATMUL_PREF_MAX_WORKSPACE_BYTES buf=0X7FFDFE56A6E8 sizeInBytes=8
[2024-04-15 13:58:01][cublasLt][115769][Api][cublasLtMatmulAlgoGetHeuristic] Adesc=[type=R_8F_E4M3 rows=128 cols=128 ld=128] Bdesc=[type=R_8F_E4M3 rows=128 cols=128 ld=128] Cdesc=[type=R_16BF rows=128 cols=128 ld=128] Ddesc=[type=R_8F_E4M3 rows=128 cols=128 ld=128] preference=[maxWavesCount=0.0 maxWorkspaceSizeinBytes=4194304] computeDesc=[computeType=COMPUTE_32F scaleType=R_32F transa=OP_T amaxDPointer=0x7fc4aee0ca00 aScalePointer=0x7fc4aee0c200 bScalePointer=0x7fc4aee0c400 cScalePointer=0x7fc4aee0c600 dScalePointer=0x7fc4aee0c800]
[2024-04-15 13:58:01][cublasLt][115769][Error][cublasLtMatmulAlgoGetHeuristic] Failed to query heuristics.
cuBLAS API failed with status 15
terminate called after throwing an instance of 'std::logic_error'
  what():  cuBLAS API failed
Aborted (core dumped)

jt-zhang commented 4 months ago

Environment: RTX 4090, CUDA 12.1, GCC 11.4.

Output of ldd ./sample_cublasLt_LtFp8Matmul:

linux-vdso.so.1 (0x00007ffc8315f000)
libcudart.so.12 => /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12 (0x00007fef0ee00000)
libcublasLt.so.12 => /usr/local/cuda/targets/x86_64-linux/lib/libcublasLt.so.12 (0x00007feeede00000)
libstdc++.so.6 => /lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007feeedbd4000)
libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007fef0f15c000)
libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007feeed9ac000)
/lib64/ld-linux-x86-64.so.2 (0x00007fef0f199000)
libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007fef0f155000)
libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007fef0f150000)
librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007fef0f14b000)
libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007fef0ed19000)

oscarbg commented 4 months ago

Hi, ~~the problem, at least on Ada cards, seems to be that they are not supported for FP8, only Hopper. The release notes currently list it as Hopper-only; hopefully Ada support gets added in the future.~~ Sorry, I was thinking of cuSPARSELt.

rsdubtso commented 4 months ago

I think @oscarbg was right to some extent. FP8 support for Ada was introduced in CUDA 12.1 Update 1: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cublas-release-12-1-update-1. So if the CUDA toolkit version above is 12.1 (and not 12.4, as in the original message at the top), this behavior is expected.

@jt-zhang: can you please double-check the CUDA toolkit version? Or just post the output of

ls -la /usr/local/cuda/targets/x86_64-linux/lib/libcublasLt.so.12

and I should be able to tell you the version based on the soname.

oscarbg commented 4 months ago

Hi @rsdubtso, sorry for asking here since it's somewhat unrelated, but is NVIDIA planning to enable "full performance" FP8 matmuls via cuBLASLt on Ada? A year after support arrived in 12.1 Update 1, FP8 on Ada still seems limited to FP32 compute. That gives no performance advantage over FP16 matmuls with FP16 compute, and this has been known for a year. For reference: https://forums.developer.nvidia.com/t/ada-geforce-rtx-4090-fp8-cublaslt-performance/250737 "I was only able to use FP32 for the compute type as this is the only supported mode in cuBLASLt right now." Thanks.

jt-zhang commented 4 months ago

@rsdubtso Thank you so much. Output of ls -la /usr/local/cuda/targets/x86_64-linux/lib/libcublasLt.so.12:

lrwxrwxrwx 1 root root 24 Feb  9  2023 /usr/local/cuda/targets/x86_64-linux/lib/libcublasLt.so.12 -> libcublasLt.so.12.1.0.26

Output of nvcc --version:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

I believe the issue is due to my using CUDA 12.1. I apologize for any inconvenience, and thank you for your assistance.

rsdubtso commented 4 months ago

Closing as the original question has been answered.

@oscarbg: have you tried FP8 matmuls with fast accumulation (CUBLASLT_MATMUL_DESC_FAST_ACCUM)? On average, it should give about a 1.4x speedup over matmuls with FP16 inputs, FP16 accumulation, and FP16 outputs. If you want to discuss this further, please open a new issue (or update the forum post).

oscarbg commented 4 months ago

> Closing as the original question has been answered.
>
> @oscarbg: have you tried FP8 matmuls with fast accumulation (CUBLASLT_MATMUL_DESC_FAST_ACCUM)? On average, it should give about a 1.4x speedup over matmuls with FP16 inputs, FP16 accumulation, and FP16 outputs. If you want to discuss this further, please open a new issue (or update the forum post).

Thanks @rsdubtso! I will try CUBLASLT_MATMUL_DESC_FAST_ACCUM to see if I can get close to a 1.4x speedup over FP16. I was hoping, though, that some cuBLASLt mode would reach close to a 2x speedup in FP8. Hopefully that comes in a future cuBLASLt release.