NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Fixed compute type for FP16 Tensor core wrapper around cublas GEMMEx #1808

Closed suachong closed 2 weeks ago

suachong commented 2 weeks ago

The compute type for the FP16 Tensor core wrapper around cublas GemmEx needs to be changed from CUBLAS_COMPUTE_16F back to CUBLAS_COMPUTE_32F. It was originally CUDA_R_32F, but was changed to CUBLAS_COMPUTE_16F during the rocblas->hipblas migration in PyTorch.

https://github.com/ROCm/apex/commit/4fa061dbfe4bc181ed879713b1aa48e1499ff907#diff-91b286344e75403377c428565cccd16bee300d73a6d6346a6b36443f50d22dd2R138

This change allows the unit test for test_fused_dense.py to pass successfully.
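Why the compute type matters can be sketched outside of cuBLAS entirely: with CUBLAS_COMPUTE_32F the FP16 products are accumulated in FP32, so rounding error stays small, while with CUBLAS_COMPUTE_16F every partial sum is rounded back to half precision and the error grows with the reduction length. A minimal NumPy simulation of the two accumulation modes (the vector length and random values here are arbitrary, chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
k = 4096  # reduction length, comparable to a GEMM inner dimension
a = rng.standard_normal(k).astype(np.float16)
b = rng.standard_normal(k).astype(np.float16)

# FP32 accumulation of FP16 products (the CUBLAS_COMPUTE_32F behavior):
# each product is widened to float32 before being added.
acc32 = np.float32(0)
for x, y in zip(a, b):
    acc32 += np.float32(x) * np.float32(y)

# FP16 accumulation (the CUBLAS_COMPUTE_16F behavior): both the product
# and every partial sum are rounded back to half precision.
acc16 = np.float16(0)
for x, y in zip(a, b):
    acc16 = np.float16(acc16 + x * y)

# Exact reference: the FP16 inputs are represented exactly in float64.
ref = np.dot(a.astype(np.float64), b.astype(np.float64))
err32 = abs(float(acc32) - ref)
err16 = abs(float(acc16) - ref)
print(f"FP32-accumulation error: {err32:.2e}")
print(f"FP16-accumulation error: {err16:.2e}")
```

The FP16-accumulated result drifts orders of magnitude further from the exact dot product, which is the kind of divergence that makes tolerance-based checks such as test_fused_dense.py fail under CUBLAS_COMPUTE_16F.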

/opt/rocm/apex/apex/contrib/test/fused_dense# pytest -v 
============================================================================== test session starts ===============================================================================
platform linux -- Python 3.9.19, pytest-7.3.2, pluggy-1.4.0 -- /opt/conda/envs/py_3.9/bin/python
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/opt/rocm-6.1.0/apex/apex/contrib/test/fused_dense/.hypothesis/examples')
rootdir: /opt/rocm-6.1.0/apex
plugins: xdist-3.3.1, xdoctest-1.1.0, hypothesis-5.35.1, flakefinder-1.1.0, rerunfailures-14.0, shard-0.1.2, cpp-2.3.0
collected 1 item                                                                                                                                                                 
Running 1 items in this shard: apex/contrib/test/fused_dense/test_fused_dense.py::FusedDenseTest::test_fused_dense

test_fused_dense.py::FusedDenseTest::test_fused_dense PASSED                                                                                                               [100%]

=============================================================================== 1 passed in 6.83s ================================================================================