Open rmrao opened 3 years ago
I tried to reproduce this issue using a master build as well as the 1.7.1
binaries with CUDA 11.0 on a Titan V, 2080 Ti, and P100, and got non-zero results:
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.get_device_name())
q, k = torch.load('failing_tensors.pt')
res = torch.bmm(q, k.transpose(1, 2))
print(res)
> 1.7.1
11.0
Tesla P100-SXM2-16GB
tensor([[[-7.6692e-01, -4.2144e-01, -1.3237e-02, ..., 4.7409e-02,
4.7409e-02, 4.7409e-02],
I'll try to grab a Titan Xp to reproduce it.
The issue is also not reproducible with a master build + CUDA11.2 and the 1.7.1
binaries + CUDA11.0 on a Titan-XP using a loop of 100k iterations.
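For anyone who wants to repeat this kind of stress test, here is a minimal sketch of such a loop. The helper name `count_bad_bmm_results` and the random stand-in tensors are assumptions for illustration, not the exact script used above; with the attached file you would load `q, k` via `torch.load('failing_tensors.pt')` instead.

```python
import torch


def count_bad_bmm_results(q, k, iterations=1000):
    """Run torch.bmm repeatedly; count results that are all zeros or contain NaN.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    bad = 0
    for _ in range(iterations):
        res = torch.bmm(q, k.transpose(1, 2))
        has_nan = bool(torch.isnan(res).any())
        all_zero = bool((res == 0).all())
        if has_nan or all_zero:
            bad += 1
    return bad


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Random stand-in tensors with made-up shapes (assumption);
    # replace with: q, k = torch.load('failing_tensors.pt')
    q = torch.randn(4, 16, 8, device=device)
    k = torch.randn(4, 16, 8, device=device)
    print(count_bad_bmm_results(q, k, iterations=100))
```

A non-zero count on the affected setup would confirm the report; zero over many iterations matches what I observed.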
🐛 Bug
Using the PyTorch 1.7.1 conda binary, the same torch.bmm call gives me the correct answer with CUDA 10.2, but returns all zeros with CUDA 11.0.
To Reproduce
Steps to reproduce the behavior:
1. Download failing_tensors.pt
2. q, k = torch.load("failing_tensors.pt")
3. torch.bmm(q, k.transpose(1, 2))
Step 3 succeeds when using CUDA 10.2, but returns either all zeros or all NaNs when using CUDA 11.0.
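One way to make "succeeds" checkable without switching CUDA versions is to compare the result against a CPU reference. The sketch below is only an illustration: the helper name `bmm_matches_cpu` and the tolerances are assumptions, and the tolerances may need loosening for lower-precision inputs.

```python
import torch


def bmm_matches_cpu(q, k, rtol=1e-3, atol=1e-5):
    """Compare torch.bmm on the tensors' own device with a float64 CPU reference.

    Hypothetical helper; rtol/atol are assumed values, not taken from the report.
    """
    res = torch.bmm(q, k.transpose(1, 2))
    ref = torch.bmm(q.double().cpu(), k.double().cpu().transpose(1, 2))
    # allclose treats NaN as a mismatch by default, so NaN results also fail.
    return torch.allclose(res.double().cpu(), ref, rtol=rtol, atol=atol)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Random stand-in tensors (assumption);
    # replace with: q, k = torch.load("failing_tensors.pt")
    q = torch.randn(4, 16, 8, device=device)
    k = torch.randn(4, 16, 8, device=device)
    print(bmm_matches_cpu(q, k))
```

With the attached tensors, I would expect this to print True on CUDA 10.2 and False on the failing CUDA 11.0 setup.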
Environment
PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 8.4.0-1ubuntu1~18.04) 8.4.0
Clang version: Could not collect
CMake version: version 3.19.0-rc1
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.0.194
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
Nvidia driver version: 450.51.05
cuDNN version: /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.7.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.0.221 h6bb024c_0
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38he904b0f_0
[conda] mkl_fft 1.2.1 py38h54f3939_0
[conda] mkl_random 1.1.1 py38h0573a6f_0
[conda] numpy 1.19.2 py38h54aff64_0
[conda] numpy-base 1.19.2 py38hfa32c7d_0
[conda] pytorch 1.7.1 py3.8_cuda11.0.221_cudnn8.0.5_0 pytorch
cc @csarofeen @ptrblck @xwang233 @ngimel