foxik opened this issue 7 months ago
@foxik Thanks for the issue! Would you like to contribute this by modifying the following file? https://github.com/keras-team/keras/blob/master/keras/backend/torch/rnn.py#L377C1-L382C30
@haifeng-jin I am not sure I can do it correctly. I assume that:

- `cudnn_ok` will probably need to consider also the current device (whether it is CUDA or not), in addition to verifying that the arguments are supported by the cuDNN implementation.
- `cudnn_ok` is currently called only for the TensorFlow backend; on the other hand, it is used only to set `supports_jit` to False, which is probably not needed for PyTorch, because the sources indicate that TorchScript can compile `torch.nn.LSTM`/`GRU`.
- `torch.nn.LSTM`/`GRU` is a whole layer including its parameters, but we need to use the given parameters. Therefore, we should probably call `torch._VF.lstm`/`gru` (a rough sketch follows after this list), but I am not sure whether that would be considered OK.
- `go_backwards` has no direct analogue in the Torch API, so some manual reversing will be needed.
- `go_backwards` would not be used in PyTorch for most usages, so it would not matter much if its implementation is not great.

In any case, for the time being I unfortunately do not have time to work on this.
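For illustration only, here is a rough sketch of what such a fast path might look like. The helper names are hypothetical, and `torch._VF.lstm` is the private entry point that `torch.nn.LSTM.forward` itself dispatches to, so its signature may change between PyTorch releases:

```python
import torch

def cudnn_usable(inputs):
    # Unlike the TF backend, the check would also have to look at the
    # current device: the cuDNN kernels only exist for CUDA tensors.
    return inputs.is_cuda and torch.backends.cudnn.is_available()

def lstm_fast(inputs, h0, c0, flat_weights, go_backwards=False):
    # flat_weights = [w_ih, w_hh, b_ih, b_hh] for one unidirectional layer,
    # in the layout that torch.nn.LSTM would hand to cuDNN.
    if go_backwards:
        # No Torch analogue of go_backwards: reverse the time axis manually.
        # Keras documents go_backwards as returning the reversed sequence,
        # so the output is left reversed.
        inputs = torch.flip(inputs, dims=[1])  # (batch, time, features)
    output, h_n, c_n = torch._VF.lstm(
        inputs,
        (h0, c0),
        flat_weights,
        True,   # has_biases
        1,      # num_layers
        0.0,    # dropout
        False,  # train
        False,  # bidirectional
        True,   # batch_first
    )
    return output, (h_n, c_n)
```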
This feature would be great indeed. Hopefully someone highly capable will attend to this sometime soon.
Hi,
are there any plans to add cuDNN-accelerated versions of LSTM and GRU to the PyTorch backend? Without cuDNN acceleration, LSTM and GRU are considerably (several times) slower, even when running on a GPU; however, we still use RNNs heavily (for example, adding them after a Transformer encoder still helps in some cases).
The `torch.nn.LSTM`/`torch.nn.GRU` modules offer cuDNN acceleration, and wrapping them in a `keras.layers.Layer` works (see the sketch below), but the resulting model is not backend-agnostic (so it cannot be used across frameworks). Thanks for consideration :pray: and cheers!
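For concreteness, a minimal sketch of such a wrapper (the `TorchLSTM` name is made up; `keras.layers.TorchModuleWrapper` is used so that Keras tracks the torch parameters). It runs only with the torch backend, which is exactly the portability problem described:

```python
import keras
import torch

class TorchLSTM(keras.layers.Layer):
    """Hypothetical wrapper: cuDNN-fast on the torch backend, but torch-only."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # torch.nn.LSTM carries its own parameters and dispatches to cuDNN
        # on CUDA devices; TorchModuleWrapper registers them with Keras.
        self.lstm = keras.layers.TorchModuleWrapper(
            torch.nn.LSTM(input_shape[-1], self.units, batch_first=True)
        )

    def call(self, inputs):
        outputs, _states = self.lstm(inputs)  # (output, (h_n, c_n))
        return outputs
```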
PS: Relatedly, `torch.nn.LSTM`/`GRU` offer bidirectional computation in a single call (by passing `bidirectional=True`) -- I am not sure how much faster it is compared to two asynchronous unidirectional computations, but if it is faster, `keras.layers.Bidirectional` would probably have to be updated to handle `keras.layers.LSTM` and `keras.layers.GRU` specifically to support it.
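For reference, the single-call bidirectional path looks like this on the PyTorch side (shapes chosen only for illustration):

```python
import torch

# bidirectional=True runs both directions in one (possibly cuDNN) call;
# the output feature dimension doubles to 2 * hidden_size.
lstm = torch.nn.LSTM(input_size=128, hidden_size=256,
                     batch_first=True, bidirectional=True)
x = torch.randn(8, 50, 128)      # (batch, time, features)
outputs, (h_n, c_n) = lstm(x)    # outputs.shape == (8, 50, 512)
```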