NVIDIA-developer-blog / code-samples

Source code examples from the Parallel Forall Blog
BSD 3-Clause "New" or "Revised" License

Getting errors running tensor-cores example #23

Open nmoran opened 5 years ago

nmoran commented 5 years ago

Running the example from the posts/tensor-cores folder, as discussed at https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/, the numbers are not as close as expected. I am getting the following output:

./TCGemm 

M = 16384, N = 16384, K = 16384. alpha = 2.000000, beta = 2.000000

Running with wmma...
Running with cuBLAS...

Checking results...
8266.587891 8267.766602
8240.230469 8241.420898
8242.393555 8243.574219
8209.478516 8210.649414
8100.519043 8101.664062
8251.499023 8252.675781
8189.156738 8190.297852
8260.410156 8261.580078
8311.802734 8313.015625
WMMA does not agree with cuBLAS! 268435456 errors!
Ivanrs297 commented 4 years ago

Which version of nvcc are you using, and did you solve it?

agschrei commented 4 years ago

Hey everyone, so I recently ran into the same problem with CUDA 11 and for me it was an issue with the device code that got generated.

If you want to run this sample on Turing you will have to make sure that you are using the -gencode arch=compute_75,code=sm_75 flags during compilation.
Trying to run this on Turing with a binary compiled for a Volta target (sm_70) will produce the error above. I'm guessing the wmma instructions are low-level enough that they are not binary-compatible between architectures.
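As a sketch, a compile line that generates device code for both Volta and Turing might look like this (the source file name and cuBLAS link flag are assumed from the repo layout, not confirmed here):

```shell
# Hypothetical compile invocation for the tensor-cores sample.
# Embeds SASS for sm_70 (Volta) and sm_75 (Turing), plus PTX for
# compute_75 so newer GPUs can JIT-compile it.
nvcc -o TCGemm simpleTensorCoreGEMM.cu -lcublas \
     -gencode arch=compute_70,code=sm_70 \
     -gencode arch=compute_75,code=sm_75 \
     -gencode arch=compute_75,code=compute_75
```

Running the resulting binary on a Turing card should then use the sm_75 code path instead of falling back to incompatible sm_70 SASS.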

I'm just leaving this here for future reference, hoping I'll save somebody a lot of head-scratching.

yofufufufu commented 1 year ago

Device: RTX 3090
In CMakeLists: set(CMAKE_CUDA_ARCHITECTURES 86)
NVCC version: 11.1
I get the same issue, can anyone help?

hychiang-git commented 1 year ago

Hi, I have the same issue.
Device: A100
NVCC version: 11.1
I tried -arch=sm_80 but it does not work for me.

The results seem correct after reducing MATRIX_M, MATRIX_N, and MATRIX_K from 16384 to 1024. I think the 0.01% relative tolerance and 1e-5 absolute tolerance in the code are too tight for large matrices like 16384x16384, where fp16 rounding error accumulates over the long K dimension.

However, I did not get a speedup with 1024x1024 matrices: wmma took 0.300032 ms, cuBLAS took 0.041984 ms.

I guess we would just use cuBLAS, or refer to the faster implementation here.

   // Use tensor cores
   cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));