keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

RNN not compatible with XLA (TF backend) #18456

Open chenmoneygithub opened 1 year ago

chenmoneygithub commented 1 year ago

RNN cannot be jit compiled, see error below:

Detected unsupported operations when trying to compile graph __inference_one_step_on_data_993[] on XLA_GPU_JIT: CudnnRNN (No registered 'CudnnRNN' OpKernel for XLA_GPU_JIT devices compatible with node {{node CudnnRNN}}){{node CudnnRNN}}

There is nothing we can really do about it, but I am opening this issue for tracking and reference purposes.

fchollet commented 1 year ago

We should disable jit_compile in "auto" mode if the backend is TF, the model contains an LSTM or GRU layer, and cuDNN is usable for them.

chenmoneygithub commented 1 year ago

@fchollet Yeah, this should be a simple fix. We can probably add a util maybe_disable_xla(), and any future cases that require disabling XLA could be covered by the same util.
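For discussion, here is a minimal sketch of what such a maybe_disable_xla() util could look like. This is purely illustrative and not the actual Keras internals: the layer-class-name check and the model/backend arguments are assumptions; the real implementation would also need to verify that cuDNN is actually usable for each layer.

```python
# Hypothetical sketch, not real Keras code: resolve jit_compile="auto"
# to False when the TF backend would hit an XLA-incompatible cuDNN RNN op.

# Layer classes that may dispatch to the CudnnRNN kernel on GPU (assumed set).
CUDNN_RNN_LAYERS = {"LSTM", "GRU", "Bidirectional"}


def maybe_disable_xla(model, jit_compile, backend="tensorflow"):
    """Return the resolved jit_compile value.

    An explicit True/False from the user is respected; only "auto" is
    downgraded, and only when the TF backend would use a cuDNN RNN kernel.
    """
    if jit_compile != "auto":
        return jit_compile  # respect an explicit user choice
    if backend != "tensorflow":
        return True  # other backends handle RNNs under jit fine
    has_cudnn_rnn = any(
        layer.__class__.__name__ in CUDNN_RNN_LAYERS for layer in model.layers
    )
    return not has_cudnn_rnn
```

Any later situations that need XLA disabled could then be added inside this one function.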

I will benchmark the RNN with XLA off, though it would make the comparison against JAX a bit odd.

NahuelCostaCortez commented 4 days ago

Any updates on this? I just want to train a model with a RNN layer and I get: Detected unsupported operations when trying to compile graph __inference_one_step_on_data_14684[] on XLA_GPU_JIT: CudnnRNNV3 (No registered 'CudnnRNNV3' OpKernel for XLA_GPU_JIT devices compatible with node {{node time_hvae_1/encoder_1/bidirectional_1/forward_lstm_1/CudnnRNNV3}}){{node time_hvae_1/encoder_1/bidirectional_1/forward_lstm_1/CudnnRNNV3}}

I'm able to run this with the JAX and PyTorch backends, but training is extremely slow (especially with PyTorch).
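A possible workaround in the meantime: the TF Keras LSTM/GRU docs list the conditions under which the layer dispatches to the cuDNN kernel (the op XLA cannot compile). Breaking any one of them, e.g. setting unroll=True or recurrent_dropout>0, forces the generic kernel, which XLA can compile, at some speed cost. A small sketch of that eligibility check (the function itself is hypothetical; the conditions are the documented ones):

```python
# Sketch of the documented cuDNN-eligibility conditions for tf.keras LSTM/GRU.
# If any condition is broken, the layer falls back to the generic
# (XLA-compilable) kernel instead of CudnnRNN.

def uses_cudnn_kernel(activation="tanh", recurrent_activation="sigmoid",
                      recurrent_dropout=0.0, unroll=False, use_bias=True):
    return (activation == "tanh"
            and recurrent_activation == "sigmoid"
            and recurrent_dropout == 0.0
            and not unroll
            and use_bias)
```

So e.g. LSTM(64, unroll=True) should jit-compile, just slower than the cuDNN path. Alternatively, passing jit_compile=False to model.compile() avoids XLA entirely.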

I really appreciate your work, but every time I want to run a quick experiment, especially with time series, I end up hitting library-related issues :(