suachong closed this 2 weeks ago
The compute type in the FP16 Tensor Core wrapper around cuBLAS `GemmEx` needs to be changed from `CUBLAS_COMPUTE_16F` to `CUBLAS_COMPUTE_32F`. The wrapper originally used `CUDA_R_32F`, but this was changed to `CUBLAS_COMPUTE_16F` during the rocblas->hipblas changes in PyTorch.
https://github.com/ROCm/apex/commit/4fa061dbfe4bc181ed879713b1aa48e1499ff907#diff-91b286344e75403377c428565cccd16bee300d73a6d6346a6b36443f50d22dd2R138
This change allows the unit test for test_fused_dense.py to pass successfully.

    /opt/rocm/apex/apex/contrib/test/fused_dense# pytest -v
    ============================= test session starts ==============================
    platform linux -- Python 3.9.19, pytest-7.3.2, pluggy-1.4.0 -- /opt/conda/envs/py_3.9/bin/python
    cachedir: .pytest_cache
    hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/opt/rocm-6.1.0/apex/apex/contrib/test/fused_dense/.hypothesis/examples')
    rootdir: /opt/rocm-6.1.0/apex
    plugins: xdist-3.3.1, xdoctest-1.1.0, hypothesis-5.35.1, flakefinder-1.1.0, rerunfailures-14.0, shard-0.1.2, cpp-2.3.0
    collected 1 item
    Running 1 items in this shard: apex/contrib/test/fused_dense/test_fused_dense.py::FusedDenseTest::test_fused_dense

    test_fused_dense.py::FusedDenseTest::test_fused_dense PASSED [100%]

    ============================== 1 passed in 6.83s ===============================
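
For readers wondering why the compute type matters, here is a minimal pure-Python sketch of the effect, assuming that a 16-bit compute type keeps the running sum of each dot product in half precision while a 32-bit compute type keeps it in single precision. It is not the apex or cuBLAS code path: the helper names (`to_fp16`, `to_fp32`, `dot`) are hypothetical, and the rounding is emulated with the standard `struct` module, so real hardware may accumulate differently.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dot(a, b, round_acc):
    """Dot product of FP16 inputs; the running sum is rounded by round_acc
    after every addition (emulated accumulator precision, an assumption)."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_acc(acc + to_fp16(x) * to_fp16(y))
    return acc


k = 4096  # inner (reduction) dimension of the hypothetical GEMM
a = [1.0] * k
b = [1.0] * k

fp16_acc = dot(a, b, to_fp16)  # emulates half-precision accumulation
fp32_acc = dot(a, b, to_fp32)  # emulates single-precision accumulation

print(fp16_acc, fp32_acc)
```

In this emulation the half-precision sum stops growing once it reaches 2048, because 2049 is not representable in FP16 and rounds back to 2048, while the single-precision sum reaches the exact value 4096. Errors of this kind are one plausible reason a tolerance-based test such as `test_fused_dense.py` would fail with a 16-bit compute type.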