
JITed GRU too slow #35998

Open usamec opened 4 years ago

usamec commented 4 years ago

🐛 Bug

It is advertised that the forward pass of JITed RNNs (e.g. GRU) is as fast as the cuDNN implementation, but that is not the case.

To Reproduce

Steps to reproduce the behavior:

See here: https://gist.github.com/usamec/af21be7b83e6b1a3f38c26136af811f3
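
A simplified sketch of the kind of comparison the gist makes (a hand-written GRU layer compiled with torch.jit.script, timed against nn.GRU, which dispatches to cuDNN on GPU; the sizes and iteration counts here are illustrative, not the gist's exact values):

import time
import torch
from torch import nn, Tensor

class GRUCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.weight_ih = nn.Parameter(0.1 * torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(0.1 * torch.randn(3 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.zeros(3 * hidden_size))
        self.bias_hh = nn.Parameter(torch.zeros(3 * hidden_size))

    def forward(self, x: Tensor, h: Tensor) -> Tensor:
        # standard GRU cell equations
        gi = torch.mm(x, self.weight_ih.t()) + self.bias_ih
        gh = torch.mm(h, self.weight_hh.t()) + self.bias_hh
        i_r, i_z, i_n = gi.chunk(3, 1)
        h_r, h_z, h_n = gh.chunk(3, 1)
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        return (1.0 - z) * n + z * h

class GRULayer(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.cell = GRUCell(input_size, hidden_size)

    def forward(self, x: Tensor, h: Tensor) -> Tensor:
        # x: (seq_len, batch, input_size), h: (batch, hidden_size)
        for t in range(x.size(0)):
            h = self.cell(x[t], h)
        return h

def bench(fn, iters: int = 100) -> float:
    fn()  # warm-up pass (triggers JIT compilation/profiling)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.time() - start

seq_len, batch, size = 100, 64, 256
x = torch.randn(seq_len, batch, size, device="cuda")
h = torch.zeros(batch, size, device="cuda")

cudnn_gru = nn.GRU(size, size).cuda()
jit_gru = torch.jit.script(GRULayer(size, size).cuda())

print("cudnn:", bench(lambda: cudnn_gru(x)))
print("jit:  ", bench(lambda: jit_gru(x, h)))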

Expected behavior

The forward pass is as fast as the cuDNN implementation.

Environment

Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.0

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: GPU 0: TITAN Xp
Nvidia driver version: 430.50
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.5.0

Versions of relevant libraries:
[pip3] numpy==1.15.0
[conda] _pytorch_select           0.2                       gpu_0  
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.0.130                      0  
[conda] mkl                       2019.4                      243  
[conda] mkl-service               2.3.0            py36he904b0f_0  
[conda] mkl_fft                   1.0.15           py36ha843d7b_0  
[conda] mkl_random                1.1.0            py36hd6b4f25_0  
[conda] numpy                     1.18.1           py36h4f9e942_0  
[conda] numpy-base                1.18.1           py36hde5b4d6_1  
[conda] pytorch                   1.4.0           py3.6_cuda10.0.130_cudnn7.6.3_0    pytorch
[conda] pytorch-qrnn              0.2.1                    pypi_0    pypi
[conda] torch-scatter             1.4.0                    pypi_0    pypi
[conda] torchvision               0.5.0                    pypi_0    pypi

I am also getting the same slowdown on a GeForce RTX 2080 Ti.

Additional context

Posted here first: https://discuss.pytorch.org/t/jited-gru-too-slow/68873

Warm starting does not change much at all.

cc @ezyang @gchanan @zou3519 @suo

zdevito commented 4 years ago

@Krovatkin is looking into this. It appears to be a regression in the way we get the programs ready to be consumed by the fuser.

daniel-p-gonzalez commented 4 years ago

Any progress on this? I get extraordinarily slow performance on all custom RNNs, despite all efforts to make it easy for the JIT to do the right thing.

The performance when loading the model in C++ (on CPU) is extremely poor as well; in fact, I replaced the forward pass with my own implementation, and even the naive first-pass loop I wrote was orders of magnitude faster. I've since optimized the C++ inference and am happy with the performance there, but GPU training in Python is still unbearably slow. I'm not sure how to work around it at this point, and it's definitely grinding my progress to a halt. Thanks for any updates @zdevito & @Krovatkin.

munael commented 4 years ago

@zdevito

@Krovatkin is looking into this. It appears to be a regression in the way we get the programs ready to be consumed by the fuser.

Is there a recent enough version that we can fall back to for this?

ZolotukhinM commented 4 years ago

@FBMachine @narfanar The fuser got disabled in 1.5, but we're working on its replacement. In the meantime, as a workaround, could you please try the following:

torch._C._jit_set_profiling_executor(False)

I would not recommend relying on it heavily if you can avoid it, because that mode will most likely be deprecated soon (and it is also an internal API that could change at any point), but it may unblock you until the regression is fixed properly.
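
A minimal usage sketch (the flag is process-global and, as noted above, an internal API; MyGRU is just a placeholder for your own scripted module):

import torch

# Switch off the profiling executor before scripting/running the model,
# so execution falls back to the legacy executor and its fuser.
torch._C._jit_set_profiling_executor(False)

# model = torch.jit.script(MyGRU(input_size, hidden_size)).cuda()
# out = model(x, h)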

daniel-p-gonzalez commented 4 years ago

Hi @ZolotukhinM, thanks for the suggested workaround. It does indeed increase performance to some degree (about 3x). I've found a rather hacky workaround that works a bit better (12x) in exchange for some accuracy tradeoffs that I think I will stick with until the fuser is updated. Thanks, and I'm looking forward to the update!

Krovatkin commented 4 years ago

After loop peeling, forward performance is equivalent to the legacy executor + old fuser (LE + OF), but the backward pass is still ~3x slower (a toy sketch of what loop peeling does follows the numbers below).

(bench) villedepommes@devfair0202:/scratch/villedepommes/pytorches/bench/benchmarks$ python gru.py --mode=te
gru_cudnn: 0.42766833305358887s
gru_cudnn_bwd: 0.6974105834960938s
gru_eager: 4.529504776000977s
gru_eager_bwd: 6.401309490203857s
compilation time:  1.709897756576538
gru_jit: 1.5193572044372559s
gru_jit_bwd: 6.728761434555054s
(bench) villedepommes@devfair0202:/scratch/villedepommes/pytorches/bench/benchmarks$ python gru.py --mode=le
gru_cudnn: 0.4282228946685791s
gru_cudnn_bwd: 0.699333906173706s
gru_eager: 4.915524005889893s
gru_eager_bwd: 7.002946853637695s
compilation time:  0.08745026588439941
gru_jit: 1.572685718536377s
gru_jit_bwd: 2.98301362991333s
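
For context, loop peeling splits the first iteration(s) out of the sequence loop so the remaining iterations form a stable, steady-state body that can be profiled, specialized, and handed to the fuser. A toy illustration of the transform (not the actual compiler pass):

from typing import List
import torch
from torch import Tensor

def loop_original(xs: List[Tensor], h: Tensor) -> Tensor:
    # every iteration runs through the same, unspecialized loop body
    for x in xs:
        h = torch.tanh(x + h)
    return h

def loop_peeled(xs: List[Tensor], h: Tensor) -> Tensor:
    # first iteration peeled out of the loop; the remaining steady-state body
    # can then be specialized and fused
    h = torch.tanh(xs[0] + h)
    for x in xs[1:]:
        h = torch.tanh(x + h)
    return h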

Krovatkin commented 4 years ago

A duplicate of https://github.com/pytorch/pytorch/issues/37455. Keeping it open just in case.