chenmoneygithub opened 1 year ago
We should disable jit_compile in auto mode if the backend is TF, the model contains an LSTM or GRU layer, and cuDNN is usable for them.
@fchollet Yea, this should be a simple fix. We could add a util like maybe_disable_xla(), and any further cases where XLA needs to be disabled could be handled by the same util.
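A minimal sketch of what such a helper could look like. The function name comes from the comment above, but everything else (the signature, the string-based layer check, the set of affected layer types) is an assumption for illustration, not the actual Keras implementation:

```python
# Hypothetical sketch of maybe_disable_xla(): resolve jit_compile="auto"
# to True/False. On the TF backend, cuDNN RNN ops (e.g. CudnnRNNV3) have
# no XLA_GPU_JIT kernel, so XLA must be turned off for models using them.

# Layer types that may dispatch to cuDNN RNN kernels (assumed list).
CUDNN_RNN_LAYERS = ("LSTM", "GRU")


def maybe_disable_xla(backend, layer_class_names, jit_compile="auto"):
    """Return the effective jit_compile flag for a model.

    backend: backend name, e.g. "tensorflow" or "jax".
    layer_class_names: class names of the model's layers.
    jit_compile: the user-supplied setting; only "auto" is resolved here.
    """
    if jit_compile != "auto":
        # The user made an explicit choice; respect it.
        return jit_compile
    if backend == "tensorflow" and any(
        name in CUDNN_RNN_LAYERS for name in layer_class_names
    ):
        # cuDNN-backed RNN layers cannot be XLA-compiled on TF.
        return False
    return True
```

For example, `maybe_disable_xla("tensorflow", ["Dense", "LSTM"])` would resolve to `False`, while the same model on the JAX backend would keep XLA enabled.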
I will benchmark RNNs with XLA off, but it would make the JAX comparison a bit odd.
Any updates on this? I just want to train a model with an RNN layer and I get:
```
Detected unsupported operations when trying to compile graph __inference_one_step_on_data_14684[] on XLA_GPU_JIT: CudnnRNNV3 (No registered 'CudnnRNNV3' OpKernel for XLA_GPU_JIT devices compatible with node {{node time_hvae_1/encoder_1/bidirectional_1/forward_lstm_1/CudnnRNNV3}}){{node time_hvae_1/encoder_1/bidirectional_1/forward_lstm_1/CudnnRNNV3}}
```
I'm able to run this with JAX and PyTorch, but training is extremely slow (especially in PyTorch).
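Until auto mode handles this case, a possible workaround is to opt out of XLA explicitly when compiling the model, since `jit_compile=False` is an accepted argument to `Model.compile`. A minimal sketch (the surrounding model definition is assumed; only the kwargs are shown, with `model.compile(**compile_kwargs)` left commented out):

```python
# Workaround sketch: pass jit_compile=False explicitly instead of the
# default "auto", so the TF backend never tries to XLA-compile the
# cuDNN LSTM/GRU kernel (CudnnRNNV3 has no XLA_GPU_JIT OpKernel).
compile_kwargs = {
    "optimizer": "adam",
    "loss": "mse",
    "jit_compile": False,  # skip XLA compilation entirely
}
# model.compile(**compile_kwargs)  # `model` assumed to contain an LSTM/GRU
```

This trades XLA's compilation speedups elsewhere in the model for the ability to use the fast cuDNN RNN kernels at all.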
I really appreciate your work, but every time I want to do a quick experiment, especially with time series, I end up hitting library-related issues :(
RNNs cannot be jit-compiled; see the error below. There's nothing we can really do about it, but I'm opening this issue for tracking and reference purposes.