NVIDIA / CUDALibrarySamples

cuBLASLt FP8 batched gemm with bias #187

Open Sunny-bot1 opened 1 month ago

Sunny-bot1 commented 1 month ago

Hi, when I try to implement an FP8 batched GEMM with a bias epilogue in cuBLASLt, based on the LtFp8Matmul sample, I get the following error:

[2024-05-22 07:06:23][cublasLt][62029][Error][cublasLtMatmulAlgoGetHeuristic] Failed to query heuristics.
cuBLAS API failed with status 15
terminate called after throwing an instance of 'std::logic_error'
  what():  cuBLAS API failed
Aborted (core dumped)

My code:

    // bias epilogue: adds a bias vector (one value per row of D, length m) broadcast across all columns
    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
    checkCublasStatus(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
    checkCublasStatus(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));

    // create matrix descriptors; batching is configured below through layout attributes
    // table of supported type combinations can be found in the documentation: https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul
    checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8F_E4M3, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8F_E4M3, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16F, m, n, ldc));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&Ddesc, CUDA_R_16F, m, n, ldc));

    // strided batch: assumes densely packed matrices (lda/ldb/ldc equal the row counts);
    // strides are computed in 64-bit so large sizes do not overflow int
    int batchCount = 2;
    int64_t stridea = static_cast<int64_t>(m) * k;
    int64_t strideb = static_cast<int64_t>(n) * k;
    int64_t stridec = static_cast<int64_t>(m) * n;
    checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
    checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, sizeof(stridea)));
    checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
    checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, sizeof(strideb)));
    checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
    checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec)));
    checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batchCount, sizeof(batchCount)));
    checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, sizeof(stridec)));

When I run only the batched GEMM, or only the GEMM with bias, it works correctly.

Does cuBLASLt support FP8 batched GEMM with a bias epilogue? Thank you very much!
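To check GPU results for the combined case (batched and bias together), a host-side reference may help. This is a minimal sketch under stated assumptions, not a cuBLASLt API: `refBatchedGemmBias` is a hypothetical helper, storage is column-major as in cuBLASLt, inputs are assumed already converted from FP8 to float, and the FP8 scale factors are ignored. The bias follows the CUBLASLT_EPILOGUE_BIAS semantics (one value per row of D, broadcast over columns). The `biasBatchStride` parameter is meant to mirror the `CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE` attribute, which, as I understand the docs, defaults to 0 so one bias vector is reused for every batch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference (not a cuBLASLt API), column-major like cuBLASLt.
// For each batch b:
//   D[b] = alpha * op(A[b]) * op(B[b]) + beta * C[b] + bias_b  (bias_b broadcast over columns)
// FP8 scale factors are ignored; inputs are assumed already converted to float.
// D uses the same layout (ldc, strideC) as C, like Ddesc in the snippet above.
std::vector<float> refBatchedGemmBias(bool transa, bool transb, int m, int n, int k,
                                      float alpha,
                                      const std::vector<float>& A, int lda, int64_t strideA,
                                      const std::vector<float>& B, int ldb, int64_t strideB,
                                      float beta,
                                      const std::vector<float>& C, int ldc, int64_t strideC,
                                      const std::vector<float>& bias, int64_t biasBatchStride,
                                      int batchCount) {
  assert(m > 0 && n > 0 && k > 0 && batchCount > 0);
  assert(lda >= (transa ? k : m) && ldb >= (transb ? n : k) && ldc >= m);
  assert(C.size() >= static_cast<size_t>((batchCount - 1) * strideC + static_cast<int64_t>(ldc) * n));
  assert(bias.size() >= static_cast<size_t>((batchCount - 1) * biasBatchStride + m));
  std::vector<float> D = C;  // elements outside each m x n block keep C's values
  for (int b = 0; b < batchCount; ++b) {
    const float* a = A.data() + b * strideA;
    const float* bm = B.data() + b * strideB;
    const float* c = C.data() + b * strideC;
    const float* bv = bias.data() + b * biasBatchStride;
    float* d = D.data() + b * strideC;
    for (int j = 0; j < n; ++j) {
      for (int i = 0; i < m; ++i) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          // op(A)(i,p): A is stored m x k for OP_N, k x m for OP_T
          float aip = transa ? a[p + static_cast<int64_t>(i) * lda] : a[i + static_cast<int64_t>(p) * lda];
          // op(B)(p,j): B is stored k x n for OP_N, n x k for OP_T
          float bpj = transb ? bm[j + static_cast<int64_t>(p) * ldb] : bm[p + static_cast<int64_t>(j) * ldb];
          acc += aip * bpj;
        }
        const int64_t idx = i + static_cast<int64_t>(j) * ldc;
        d[idx] = alpha * acc + beta * c[idx] + bv[i];  // bias indexed by row only
      }
    }
  }
  return D;
}
```

Copying the GPU output back to the host (converted to float) and comparing it against this reference, with a tolerance suited to FP8 inputs, would show whether a given configuration produces the expected result.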

Sunny-bot1 commented 1 month ago

With CUBLASLT_LOG_LEVEL=5 set, the log is:

[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtCreate] lightHandle=0X7FFDB5ABD8F8
[2024-05-22 07:22:32][cublasLt][65020][Info][cublasLtCreate] cuBLASLt v12.4 device #0: multiProcessorCount=92 maxSharedMemPerBlock=49152 maxSharedMemoryPerBlockOptin=101376
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatmulDescCreate] matmulDesc=0X7FFDB5ABD5E8 computeType=COMPUTE_32F scaleType=0
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X55A2EFDA0480 attr=MATMUL_DESC_TRANSA buf=0X7FFDB5ABD5CC sizeInBytes=4
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X55A2EFDA0480 attr=MATMUL_DESC_TRANSB buf=0X7FFDB5ABD5D0 sizeInBytes=4
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X55A2EFDA0480 attr=MATMUL_DESC_BIAS_POINTER buf=0X7FFDB5ABD560 sizeInBytes=8
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatmulDescSetAttribute] matmulDesc=0X55A2EFDA0480 attr=MATMUL_DESC_EPILOGUE buf=0X7FFDB5ABD5DC sizeInBytes=4
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFDB5ABD5F0 type=R_8F_E4M3 rows=32 cols=16 ld=32
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFDB5ABD5F8 type=R_8F_E4M3 rows=32 cols=64 ld=32
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFDB5ABD600 type=R_16F rows=16 cols=64 ld=16
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutCreate] matLayout=0X7FFDB5ABD608 type=R_16F rows=16 cols=64 ld=16
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutSetAttribute] matLayout=0X55A2F0872CB0 attr=MATRIX_LAYOUT_BATCH_COUNT buf=0X7FFDB5ABD5E0 sizeInBytes=4
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutSetAttribute] matLayout=0X55A2F0872CB0 attr=MATRIX_LAYOUT_STRIDED_BATCH_OFFSET buf=0X7FFDB5ABD618 sizeInBytes=8
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutSetAttribute] matLayout=0X55A2F0872D00 attr=MATRIX_LAYOUT_BATCH_COUNT buf=0X7FFDB5ABD5E0 sizeInBytes=4
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutSetAttribute] matLayout=0X55A2F0872D00 attr=MATRIX_LAYOUT_STRIDED_BATCH_OFFSET buf=0X7FFDB5ABD620 sizeInBytes=8
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutSetAttribute] matLayout=0X55A2F0872D50 attr=MATRIX_LAYOUT_BATCH_COUNT buf=0X7FFDB5ABD5E0 sizeInBytes=4
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutSetAttribute] matLayout=0X55A2F0872D50 attr=MATRIX_LAYOUT_STRIDED_BATCH_OFFSET buf=0X7FFDB5ABD628 sizeInBytes=8
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutSetAttribute] matLayout=0X55A2F0872DA0 attr=MATRIX_LAYOUT_BATCH_COUNT buf=0X7FFDB5ABD5E0 sizeInBytes=4
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatrixLayoutSetAttribute] matLayout=0X55A2F0872DA0 attr=MATRIX_LAYOUT_STRIDED_BATCH_OFFSET buf=0X7FFDB5ABD628 sizeInBytes=8
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatmulPreferenceCreate] matmulPref=0X7FFDB5ABD610
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatmulPreferenceSetAttribute] pref=0X55A2F0872E30 attr=MATMUL_PREF_MAX_WORKSPACE_BYTES buf=0X7FFDB5ABD718 sizeInBytes=8
[2024-05-22 07:22:32][cublasLt][65020][Api][cublasLtMatmulAlgoGetHeuristic] Adesc=[type=R_8F_E4M3 rows=32 cols=16 ld=32 batchCount=2 stridedBatchOffset=512] Bdesc=[type=R_8F_E4M3 rows=32 cols=64 ld=32 batchCount=2 stridedBatchOffset=2048] Cdesc=[type=R_16F rows=16 cols=64 ld=16 batchCount=2 stridedBatchOffset=1024] Ddesc=[type=R_16F rows=16 cols=64 ld=16 batchCount=2 stridedBatchOffset=1024] preference=[maxWavesCount=0.0 maxWorkspaceSizeinBytes=33554432] computeDesc=[computeType=COMPUTE_32F scaleType=R_32F transa=OP_T epilogue=EPILOGUE_BIAS biasPointer=0x7f83f3a07000]
[2024-05-22 07:22:32][cublasLt][65020][Error][cublasLtMatmulAlgoGetHeuristic] Failed to query heuristics.
cuBLAS API failed with status 15
terminate called after throwing an instance of 'std::logic_error'
  what():  cuBLAS API failed
Aborted (core dumped)
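The stridedBatchOffset values in the heuristic query (512 for A, 2048 for B, 1024 for C and D) are consistent with densely packed column-major matrices, where each offset equals ld × cols of that operand. A minimal sketch of that check follows; `packedBatchStride` is a hypothetical helper, and it uses 64-bit arguments so that large sizes do not overflow the way an `int` product like `m * k` can when `m` and `k` are `int`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not a cuBLASLt API): batch offset, in elements, between
// consecutive column-major matrices stored back to back. Each matrix spans
// ld * cols elements. Arguments are int64_t so the product is computed in 64 bits.
int64_t packedBatchStride(int64_t ld, int64_t cols) {
  assert(ld > 0 && cols > 0);
  return ld * cols;
}
```

For example, `packedBatchStride(32, 16)` gives 512, matching the A offset in the log (rows=32 cols=16 ld=32).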

Sunny-bot1 commented 1 month ago

When I set

    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;

or

    cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;

it runs correctly.

Only the mode without an activation, CUBLASLT_EPILOGUE_BIAS, reports the error.
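When comparing outputs across the epilogues that do run, it may help to model the element-wise step on the host. This is a hypothetical sketch (`applyEpilogue` and `HostEpilogue` are made-up names) of what the cuBLASLt docs describe for CUBLASLT_EPILOGUE_BIAS (add the bias, one value per row, broadcast over columns) and CUBLASLT_EPILOGUE_RELU_BIAS (the same, followed by ReLU). GELU_BIAS is left out because this sketch does not try to reproduce the exact GELU formula cuBLASLt uses.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of two cuBLASLt epilogues, for checking results:
//   Bias     ~ CUBLASLT_EPILOGUE_BIAS:      x + bias[row]
//   ReluBias ~ CUBLASLT_EPILOGUE_RELU_BIAS: max(0, x + bias[row])
enum class HostEpilogue { Bias, ReluBias };

// Applies the epilogue in place to a column-major m x n matrix with leading dimension ld.
void applyEpilogue(std::vector<float>& x, int m, int n, int ld,
                   const std::vector<float>& bias, HostEpilogue ep) {
  assert(m > 0 && n > 0 && ld >= m);
  assert(bias.size() >= static_cast<size_t>(m));
  assert(x.size() >= static_cast<size_t>(ld) * (n - 1) + m);
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float& v = x[i + static_cast<size_t>(j) * ld];
      v += bias[i];  // bias is broadcast across columns
      if (ep == HostEpilogue::ReluBias) v = std::max(0.0f, v);
    }
  }
}
```

With the bias folded in this way, the only difference between the BIAS and RELU_BIAS outputs is the final clamp at zero.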