NVIDIA / nv-wavenet

Reference implementation of real-time autoregressive wavenet inference
BSD 3-Clause "New" or "Revised" License

nv_wavenet_test.py fails with a batch size larger than one #14

Closed PetrochukM closed 6 years ago

PetrochukM commented 6 years ago

Issue description

Running nv_wavenet_test.py with an increased batch size causes an error.

Offending Code

    # Load the serialized model and conditioning input onto the GPU
    model = torch.load("model.pt", torch.device('cuda'))
    wavenet = nv_wavenet.NVWaveNet(**model)
    cond_input = torch.load("cond_input.pt", torch.device('cuda'))
    cond_input = cond_input.repeat(1, 2, 1, 1)  # Increase the batch size to 2

    samples = wavenet.infer(cond_input, nv_wavenet.Impl.AUTO)

Error

GPUassert: invalid argument ../nv_wavenet.cuh 547

System Info

PyTorch version: 0.4.0
Python version: 3.6.4
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.9) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 390.30
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.0.5
/usr/local/lib/python2.7/dist-packages/torch/lib/libcudnn-7a90c013.so.7.0.5
/usr/local/lib/python3.5/dist-packages/torch/lib/libcudnn-3f9a723f.so.6.0.21

RPrenger commented 6 years ago

@PetrochukM I've added a PR that fixes the issue. I had forgotten to add the batch dimension in the output. The rest of the test will still fail in the Python code, because the wave-writing code expects a 1-d tensor.
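
Until the tests catch up, here's a minimal sketch of writing one wav file per batch element. It assumes `samples` is the `[batch, num_samples]` int32 tensor returned by `infer()` and that its values are 8-bit mu-law indices; the decode helper and the 16 kHz sample rate are stand-ins, not the repo's own utilities:

    import numpy as np
    from scipy.io import wavfile

    def mu_law_expand(indices, mu=255):
        """Map integer mu-law indices in [0, mu] back to floats in [-1, 1]."""
        y = 2.0 * indices.astype(np.float64) / mu - 1.0           # [0, mu] -> [-1, 1]
        return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu  # inverse mu-law companding

    audio = samples.cpu().numpy()                      # shape [batch, num_samples]
    for b in range(audio.shape[0]):
        decoded = mu_law_expand(audio[b])              # one 1-d sequence per batch element
        pcm = (decoded * 32767.0).astype(np.int16)     # scale to 16-bit PCM
        wavfile.write("sample_{}.wav".format(b), 16000, pcm)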

PetrochukM commented 6 years ago

@RPrenger Thanks for the PR! Let me test it.

PetrochukM commented 6 years ago

Submitted a PR to fix tests: https://github.com/NVIDIA/nv-wavenet/pull/17


Oddly, I found that the kernel is deterministic between runs but not between batch elements:

nv-wavenet/pytorch$ python3.6 nv_wavenet_test.py
torch.Size([2, 147800])
tensor([ 121,  141,  115,  139,  121,  133,  125,  145,  138,  145,
         130,  151,  132,  117,  **123**,  134,  119,  119,  **171**,  121,
         142,  151,  152,  149,  165], dtype=torch.int32, device='cuda:0')
tensor([ 132,  144,  143,  129,  144,  117,  119,  123,  132,  140,
         131,  121,  100,   91,  **176**,  146,   93,  140,  **110**,  113,
          96,  168,  160,  102,  158], dtype=torch.int32, device='cuda:0')

nv-wavenet/pytorch$ python3.6 nv_wavenet_test.py
torch.Size([2, 147800])
tensor([ 121,  141,  115,  139,  121,  133,  125,  145,  138,  145,
         130,  151,  132,  117,  **123**,  134,  119,  119,  **171**,  121,
         142,  151,  152,  149,  165], dtype=torch.int32, device='cuda:0')
tensor([ 132,  144,  143,  129,  144,  117,  119,  123,  132,  140,
         131,  121,  100,   91,  **176**,  146,   93,  140,  **110**,  113,
          96,  168,  160,  102,  158], dtype=torch.int32, device='cuda:0')

This uses the same code as above, with the conditioning tensor duplicated: cond_input = cond_input.repeat(1, 2, 1, 1)  # Increase the batch size to 2

I checked the audio output and both samples sounded reasonable. This is more of an interesting observation than an issue.
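
For reference, a minimal sketch of the check behind that observation, continuing from the snippet above (so `wavenet` and `cond_input` already exist, with batch size 2 via `repeat`):

    import torch

    # Run inference twice with identical conditioning input
    samples_a = wavenet.infer(cond_input, nv_wavenet.Impl.AUTO)
    samples_b = wavenet.infer(cond_input, nv_wavenet.Impl.AUTO)

    # Deterministic between runs: the full [batch, num_samples] outputs match
    print(torch.equal(samples_a, samples_b))        # True in my runs

    # But not between batch elements, even though both elements received
    # the same (repeated) conditioning input
    print(torch.equal(samples_a[0], samples_a[1]))  # False in my runs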