pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Unexpectedly slow dropout in stacked RNN/LSTM/GRU #50879

Open gaelm opened 3 years ago

gaelm commented 3 years ago

🐛 Bug

With dropout enabled, a manually implemented stack of single-layer RNN/LSTM/GRU modules with dropout applied between them (split_fw below) runs faster than the standard multi-layer PyTorch RNN/LSTM/GRU module with the same dropout (std_fw below).

Here is the profiler output for 20 forward passes. The time printed above each table is the average per pass.

std_fw: 0.118355s
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
             aten::_cudnn_rnn        99.56%        2.355s        99.73%        2.359s     117.932ms        2.385s        99.96%        2.386s     119.278ms           0 b           0 b       3.10 Gb      -2.03 Gb            20  
                  aten::empty         0.14%       3.315ms         0.14%       3.315ms      23.679us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       5.14 Gb       5.14 Gb           140  
                  aten::zeros         0.06%       1.389ms         0.13%       2.984ms     149.181us     152.675us         0.01%     417.696us      20.885us           0 b           0 b       1.55 Mb           0 b            20  
                   aten::set_         0.05%       1.264ms         0.05%       1.264ms      63.177us     183.099us         0.01%     183.099us       9.155us           0 b           0 b           0 b           0 b            20  
               aten::rnn_tanh         0.04%     921.044us        99.87%        2.362s     118.098ms     182.123us         0.01%        2.386s     119.298ms           0 b           0 b       3.10 Gb           0 b            20  
                 aten::select         0.04%     850.081us         0.06%       1.521ms      38.023us     121.574us         0.01%     121.574us       3.039us           0 b           0 b           0 b           0 b            40  
             aten::as_strided         0.03%     670.827us         0.03%     670.827us      16.771us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            40  
                  aten::fill_         0.03%     608.481us         0.03%     608.481us      30.424us     175.976us         0.01%     175.976us       8.799us           0 b           0 b           0 b           0 b            20  
                  aten::zero_         0.02%     487.727us         0.05%       1.096ms      54.810us      89.045us         0.00%     265.021us      13.251us           0 b           0 b           0 b           0 b            20  
    aten::cudnn_is_acceptable         0.02%     412.776us         0.02%     412.776us      20.639us      42.123us         0.00%      42.123us       2.106us           0 b           0 b           0 b           0 b            20  
             aten::contiguous         0.02%     396.091us         0.02%     396.091us      19.805us      49.202us         0.00%      49.202us       2.460us           0 b           0 b           0 b           0 b            20  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.365s
CUDA time total: 2.386s

split_fw: 0.0514706s
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
             aten::_cudnn_rnn        95.95%     983.971ms        97.13%     995.994ms      16.600ms     990.313ms        96.48%     995.946ms      16.599ms           0 b           0 b       3.04 Gb      -6.58 Gb            60  
                  aten::empty         1.12%      11.484ms         1.12%      11.484ms      22.968us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      10.89 Gb      10.89 Gb           500  
                 aten::stride         0.72%       7.340ms         0.72%       7.340ms      15.291us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           480  
                   aten::set_         0.36%       3.729ms         0.36%       3.729ms      62.142us       3.694ms         0.36%       3.694ms      61.561us           0 b           0 b           0 b           0 b            60  
         aten::_fused_dropout         0.25%       2.542ms         1.19%      12.212ms     305.294us      16.737ms         1.63%      16.737ms     418.431us           0 b           0 b       1.27 Gb           0 b            40  
                 aten::select         0.24%       2.487ms         0.44%       4.477ms      37.309us       4.484ms         0.44%       4.484ms      37.366us           0 b           0 b           0 b           0 b           120  
               aten::rnn_tanh         0.24%       2.442ms        98.05%        1.006s      16.759ms       4.621ms         0.45%        1.005s      16.757ms           0 b           0 b       3.04 Gb           0 b            60  
             aten::as_strided         0.19%       1.990ms         0.19%       1.990ms      16.582us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           120  
                  aten::zeros         0.18%       1.828ms         0.59%       6.049ms     100.812us       1.410ms         0.14%       2.612ms      43.526us           0 b           0 b       1.55 Mb           0 b            60  
                aten::dropout         0.17%       1.697ms         1.36%      13.909ms     347.727us       1.639ms         0.16%      18.376ms     459.399us           0 b           0 b       1.27 Gb           0 b            40  
                  aten::fill_         0.15%       1.586ms         0.15%       1.586ms      26.431us     674.315us         0.07%     674.315us      11.239us           0 b           0 b           0 b           0 b            60  
                  aten::zero_         0.13%       1.342ms         0.29%       2.928ms      48.799us     526.995us         0.05%       1.201ms      20.022us           0 b           0 b           0 b           0 b            60  
    aten::cudnn_is_acceptable         0.11%       1.176ms         0.11%       1.176ms      19.605us       1.167ms         0.11%       1.167ms      19.451us           0 b           0 b           0 b           0 b            60  
             aten::contiguous         0.11%       1.155ms         0.11%       1.155ms      19.244us       1.149ms         0.11%       1.149ms      19.145us           0 b           0 b           0 b           0 b            60  
             aten::empty_like         0.07%     704.742us         0.15%       1.530ms      38.256us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       1.01 Gb           0 b            40  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.025s
CUDA time total: 1.026s
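The two tables can be reduced to per-pass GPU cost with a little arithmetic. The sketch below only reuses figures copied from the tables (the CUDA totals of aten::_cudnn_rnn and aten::dropout over 20 passes); the constant and function names are illustrative. It shows that the single 3-layer cuDNN call in std_fw costs about 119 ms per pass, while the three 1-layer calls in split_fw cost about 50 ms together, so the gap sits inside aten::_cudnn_rnn rather than in the extra dropout kernels, which take under 1 ms per pass.

```python
# Figures copied from the two profiler tables above (20 forward passes each).
PASSES = 20
STD_CUDNN_RNN_MS = 2386.0     # std_fw: aten::_cudnn_rnn CUDA total (2.386 s, 20 calls)
SPLIT_CUDNN_RNN_MS = 995.946  # split_fw: aten::_cudnn_rnn CUDA total (60 calls)
SPLIT_DROPOUT_MS = 18.376     # split_fw: aten::dropout CUDA total (40 calls)


def per_pass(total_ms, passes=PASSES):
    """Average GPU time per forward pass, in milliseconds."""
    return total_ms / passes


std_ms = per_pass(STD_CUDNN_RNN_MS)      # one 3-layer cuDNN call per pass
split_ms = per_pass(SPLIT_CUDNN_RNN_MS)  # three 1-layer cuDNN calls per pass
dropout_ms = per_pass(SPLIT_DROPOUT_MS)  # two F.dropout calls per pass
ratio = std_ms / (split_ms + dropout_ms)

print(f"std_fw   cuDNN RNN per pass: {std_ms:.1f} ms")
print(f"split_fw cuDNN RNN per pass: {split_ms:.1f} ms ({split_ms / 3:.1f} ms per layer)")
print(f"split_fw dropout per pass:   {dropout_ms:.2f} ms")
print(f"std_fw / split_fw GPU time:  {ratio:.2f}x")
```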

To Reproduce

import timeit
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.profiler as profiler

# Built-in stacked RNN: 3 layers in one module; inter-layer dropout is handled by cuDNN.
class StdRNN(nn.Module):
    def __init__(self, dropout, bidirectional):
        super().__init__()
        self.rnn = nn.RNN(14, 53, 3, dropout=dropout, bidirectional=bidirectional)

    def forward(self, src):
        return self.rnn(src)

# Manual stack: three 1-layer RNNs with F.dropout between them, matching
# where nn.RNN applies dropout (after every layer except the last).
class SplitRNN(nn.Module):
    def __init__(self, dropout, bidirectional):
        super().__init__()
        self.dropout = dropout
        factor = 2 if bidirectional else 1  # a bidirectional layer outputs 2 * 53 features
        self.rnn1 = nn.RNN(14, 53, 1, dropout=0, bidirectional=bidirectional)
        self.rnn2 = nn.RNN(factor * 53, 53, 1, dropout=0, bidirectional=bidirectional)
        self.rnn3 = nn.RNN(factor * 53, 53, 1, dropout=0, bidirectional=bidirectional)

    def forward(self, src):
        output, _ = self.rnn1(src)
        if self.dropout > 0:
            output = F.dropout(output, self.dropout, self.training)
        output, _ = self.rnn2(output)
        if self.dropout > 0:
            output = F.dropout(output, self.dropout, self.training)
        return self.rnn3(output)

def dropout_bug(dropout, bidirectional=True, device="cuda:0", timeit_count=20):
    device = torch.device(device)
    src = torch.randn(1000, 64, 14, device=device)  # (seq_len, batch, features)

    def run(model):
        model(src)
        # CUDA kernels run asynchronously: wait for them so timeit measures the GPU work.
        torch.cuda.synchronize(device)

    # Modules start in training mode, so dropout is active in both models.
    std_rnn = StdRNN(dropout, bidirectional).to(device)
    run(std_rnn)  # warm-up: keep one-time setup out of the measurement
    with profiler.profile(use_cuda=True, with_stack=True, profile_memory=True) as prof:
        std_time = timeit.timeit(lambda: run(std_rnn), number=timeit_count) / timeit_count
        print(f"std_fw: {std_time:.6}s")
    print(prof.key_averages(group_by_input_shape=True).table(sort_by='self_cpu_time_total', row_limit=100))

    split_rnn = SplitRNN(dropout, bidirectional).to(device)
    run(split_rnn)  # warm-up
    with profiler.profile(use_cuda=True, with_stack=True, profile_memory=True) as prof:
        split_time = timeit.timeit(lambda: run(split_rnn), number=timeit_count) / timeit_count
        print(f"split_fw: {split_time:.6}s")
    print(prof.key_averages(group_by_input_shape=True).table(sort_by='self_cpu_time_total', row_limit=100))

if __name__ == "__main__":
    # The dropout value behind the tables above is not stated; 0.5 is an assumed example.
    dropout_bug(dropout=0.5)
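The title mentions LSTM and GRU as well, while the repro only uses nn.RNN. A possible generalization of SplitRNN as a workaround is sketched below; the class name ManualStack and its arguments are hypothetical, not a PyTorch API. It stacks single-layer modules of any of the three types and applies F.dropout between layers. Whether it is actually faster for LSTM/GRU on a given cuDNN version is not something this sketch establishes; profile it the same way as above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ManualStack(nn.Module):
    """Hypothetical workaround: stack 1-layer nn.RNN / nn.LSTM / nn.GRU modules
    and apply F.dropout between them instead of passing dropout= to one module."""

    def __init__(self, cell_cls, input_size, hidden_size, num_layers,
                 dropout=0.0, bidirectional=False):
        super().__init__()
        self.dropout = dropout
        factor = 2 if bidirectional else 1  # bidirectional layers concatenate both directions
        self.layers = nn.ModuleList(
            cell_cls(input_size if i == 0 else factor * hidden_size, hidden_size,
                     num_layers=1, bidirectional=bidirectional)
            for i in range(num_layers)
        )

    def forward(self, src):
        output = src
        hiddens = []  # per-layer final states: a tensor for RNN/GRU, an (h, c) tuple for LSTM
        for i, layer in enumerate(self.layers):
            output, hidden = layer(output)
            hiddens.append(hidden)
            # Same placement as the built-in modules: after every layer except the last.
            if self.dropout > 0 and i < len(self.layers) - 1:
                output = F.dropout(output, self.dropout, self.training)
        return output, hiddens
```

Returning a list of per-layer states, rather than one stacked h_n like the built-in modules, keeps the sketch uniform across types, since nn.LSTM returns an (h, c) tuple per layer. For the repro's shapes, usage would be `ManualStack(nn.LSTM, 14, 53, 3, dropout=0.5, bidirectional=True)`.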

Expected behavior

I expect the standard stacked RNN to be at least as fast as a manually stacked version, since it can run all layers in a single cuDNN call.
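That expectation assumes both models compute the same function. A quick CPU check is sketched below, assuming that copying the `_l0` parameters of the k-th single-layer module into the `_l{k}` parameters of a 3-layer nn.RNN is the right mapping (the helper name `copy_split_into_std` is hypothetical). In eval mode, where the inter-layer dropout is off, both paths should produce the same output up to floating-point noise.

```python
import torch
import torch.nn as nn

HIDDEN = 53
FACTOR = 2  # bidirectional: each layer outputs 2 * HIDDEN features


def copy_split_into_std(split_layers, std):
    """Hypothetical helper: copy layer 0 of each 1-layer RNN into layer k of a multi-layer RNN."""
    with torch.no_grad():
        for k, layer in enumerate(split_layers):
            for name, param in layer.named_parameters():
                # e.g. 'bias_hh_l0_reverse' -> 'bias_hh_l2_reverse' when k == 2
                getattr(std, name.replace("_l0", f"_l{k}")).copy_(param)


torch.manual_seed(0)
std = nn.RNN(14, HIDDEN, 3, dropout=0.5, bidirectional=True)
split = [nn.RNN(14, HIDDEN, 1, bidirectional=True),
         nn.RNN(FACTOR * HIDDEN, HIDDEN, 1, bidirectional=True),
         nn.RNN(FACTOR * HIDDEN, HIDDEN, 1, bidirectional=True)]
copy_split_into_std(split, std)
std.eval()  # eval mode disables the inter-layer dropout

src = torch.randn(50, 4, 14)  # (seq_len, batch, features)
with torch.no_grad():
    std_out, _ = std(src)
    split_out = src
    for layer in split:
        split_out, _ = layer(split_out)

max_diff = (std_out - split_out).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```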

Environment

PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.4 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 450.36.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] botorch==0.3.3
[pip] gpytorch==1.3.0
[pip] numpy==1.18.5
[pip] torch==1.7.1
[pip] torchvision==0.7.0
[conda] blas 1.0 mkl
[conda] botorch 0.3.3 pypi_0 pypi
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] gpytorch 1.3.0 pypi_0 pypi
[conda] mkl 2020.1 217
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.1.0 py37h23d657b_0
[conda] mkl_random 1.1.1 py37h0573a6f_0
[conda] numpy 1.18.5 py37ha1c710e_0
[conda] numpy-base 1.18.5 py37hde5b4d6_0
[conda] torch 1.7.1 pypi_0 pypi
[conda] torchvision 0.7.0 py37_cu101 pytorch
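The environment report could not collect the cuDNN version, which matters for a slowdown that sits inside aten::_cudnn_rnn. A small sketch for reading the relevant versions directly from the installed torch package; the function name `runtime_versions` is illustrative, while the attributes it reads (`torch.version.cuda`, `torch.backends.cudnn.version()`, `torch.cuda.get_device_name`) are standard PyTorch.

```python
import torch


def runtime_versions():
    """Illustrative helper: versions relevant to this report.
    'cudnn' is None when cuDNN is not available in this build."""
    info = {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # CUDA version torch was built with (None for CPU builds)
        "cudnn": torch.backends.cudnn.version(),
        "cuda_available": torch.cuda.is_available(),
        "gpus": [],
    }
    if info["cuda_available"]:
        info["gpus"] = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
    return info


if __name__ == "__main__":
    for key, value in runtime_versions().items():
        print(f"{key}: {value}")
```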

Additional context

I first asked for help on the PyTorch forums before treating this as a bug.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @csarofeen @ptrblck @xwang233

zou3519 commented 3 years ago

I haven't looked at the code, but a 2x slowdown seems pretty bad. I think this should be investigated with high priority.

xwang233 commented 3 years ago

Thanks for the report. Is this the PyTorch 1.7.1 + CUDA 10.2 wheel downloaded with pip? Can you try the CUDA 11.0 wheel and see whether the performance difference still exists?

gaelm commented 3 years ago

I have tried PyTorch 1.7.1 with CUDA 11.0, and the issue still exists.

gaelm commented 3 years ago

Quick update: the same issue occurs with PyTorch 1.8.0 and CUDA 11.0.

xwang233 commented 3 years ago

Thanks for the report. I'm able to reproduce it. I have reported it to the cuDNN team.