keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

[Proposal] API tuning for use_cudnn in LSTM/other layers #19647

Open LukeWood opened 2 weeks ago

LukeWood commented 2 weeks ago

As it stands right now, the use_cudnn argument is a bit difficult to use. In particular, a few things stick out to me.

They are, loosely:

1.) use_cudnn, as a backend-level switch, is very context dependent: it requires your Keras code to be aware of your installation/backend/etc., even though it is part of the high-level API. I think there should perhaps be a value such as use_cudnn="require-compatible-configuration" that validates the layer arguments to make sure the criteria for using cuDNN with LSTM/GRU/etc. are all met, but does NOT throw if cuDNN is unavailable for environment-related reasons (perhaps it should warn instead?).

Basically, this is nice because it lets you write portable code, while not risking leaving performance on the table in cudnn-compatible environments.
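A minimal sketch of how the proposed extra value could be resolved, assuming two pre-computed booleans: whether the layer's own arguments meet the cuDNN criteria, and whether the environment (backend, GPU, cuDNN install) can run it. `resolve_use_cudnn` and the `REQUIRE_COMPATIBLE_CONFIG` constant are hypothetical, and the choice of `ValueError` and `warnings.warn` is an assumption; `"auto"` stands in for the existing silent-fallback default.

```python
import warnings

# Hypothetical value proposed in this issue; it is not part of the Keras API.
REQUIRE_COMPATIBLE_CONFIG = "require-compatible-configuration"


def resolve_use_cudnn(use_cudnn, config_ok, env_ok):
    """Decide whether to run the cuDNN kernel (hypothetical helper).

    use_cudnn: "auto", True, False, or REQUIRE_COMPATIBLE_CONFIG.
    config_ok: the layer's own arguments satisfy the cuDNN criteria.
    env_ok: the backend, hardware, and installation can run cuDNN.
    """
    if use_cudnn is False:
        return False
    if use_cudnn is True:
        # "Fail if you can't use cuDNN", whatever the cause.
        if not (config_ok and env_ok):
            raise ValueError("use_cudnn=True, but cuDNN cannot be used here.")
        return True
    if use_cudnn == "auto":
        # Default behavior: fall back silently to the generic kernel.
        return config_ok and env_ok
    if use_cudnn == REQUIRE_COMPATIBLE_CONFIG:
        # The user's own arguments are always validated, even on CPU...
        if not config_ok:
            raise ValueError("The layer arguments are not cuDNN-compatible.")
        # ...but an unsupported environment only triggers a warning.
        if not env_ok:
            warnings.warn("cuDNN is unavailable in this environment; "
                          "falling back to the generic kernel.")
            return False
        return True
    raise ValueError(f"Unrecognized value for use_cudnn: {use_cudnn!r}")
```

Under this sketch the same script runs on a CPU-only machine (a warning, then the fallback path), while a configuration that fails the criteria raises in every environment, so the mistake surfaces during local development rather than on the GPU box.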

2.) I'm struggling a bit to get use_cudnn=True to work in Keras. Part of my confusion is that the error messages don't follow the typical Keras style: in particular, they include no context. To alleviate this, I propose we always provide a reason why cuDNN was not used. Maybe it's because TensorFlow eager execution was disabled, maybe it's because you're using the numpy backend, maybe your configuration is wrong. Whatever the case, I think we could polish up the error messages here to say how to fix the problem and save future users some round trips.
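One way to make such a message self-explanatory, sketched under the assumption that each check records a structured reason instead of a bare failure. `FallbackReason` and `format_cudnn_error` are hypothetical names, and the example reasons echo the cases discussed in this thread.

```python
from dataclasses import dataclass


@dataclass
class FallbackReason:
    """One reason the cuDNN kernel was not selected (hypothetical structure)."""
    source: str      # "environment" or "configuration"
    detail: str      # what was detected
    suggestion: str  # what the user can change


def format_cudnn_error(layer_name, reasons):
    """Build a message that states every reason cuDNN was not used."""
    lines = [
        f"Layer '{layer_name}' was created with use_cudnn=True, "
        "but the cuDNN kernel cannot be used:"
    ]
    for reason in reasons:
        lines.append(f"  - [{reason.source}] {reason.detail} "
                     f"Fix: {reason.suggestion}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(format_cudnn_error("lstm", [
        FallbackReason("environment",
                       "The active backend is 'numpy'.",
                       "Use a backend that provides a cuDNN kernel."),
        FallbackReason("configuration",
                       "recurrent_dropout=0.2 is not supported by cuDNN.",
                       "Set recurrent_dropout=0."),
    ]))
```

Listing every reason at once, rather than stopping at the first, is what saves the round trips: a user with both an environment and a configuration problem sees both in one run.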


Happy to help out with contributing as always.

After some feedback on 1.) I'd be happy to try my hand at adding it.

Also happy to start contributing cases to 2.) if you are open to that.

Cheers!

fchollet commented 2 weeks ago

Right now you're basically not supposed to pass use_cudnn in the general case -- just pass nothing, and cuDNN will be used if it's feasible.

Passing use_cudnn=True is something you'd only do if you want to error out when cuDNN isn't available, as opposed to falling back to a non-cuDNN implementation.

> To alleviate this, I propose we always provide a reason why cuDNN was not used.

Definitely, this is a good call. There are multiple user-facing factors at play, like recurrent dropout, masking (which will actually give you a targeted error message), activations, etc. And of course availability of cuDNN in your environment.

LukeWood commented 2 weeks ago

> Right now you're basically not supposed to pass use_cudnn in the general case -- just pass nothing, and cuDNN will be used if it's feasible.

I guess what I'd be hoping to gain from an additional API here would be a way to express in code:

"I want to make sure that my Keras code is not the reason that we aren't using cuDNN. If the backend/environment doesn't support it, that is alright, but I want to make sure I do not accidentally have the flags misconfigured."

Though perhaps just improving the error messages is sufficient.

hertschuh commented 2 weeks ago

Hey @LukeWood,

> I want to make sure that my Keras code is not the reason that we aren't using cuDNN

That's how I look at the use_cudnn argument. It really means "fail if you can't use cudnn".

> Though perhaps just improving the error messages is sufficient.

I think that's the right call. Are you willing to contribute on this?

LukeWood commented 2 weeks ago

Yes, happy to contribute!

> It really means "fail if you can't use cudnn".

I think the key point is that there are actually two intentions users may want to express. One is "fail if you can't use cuDNN, regardless of cause"; the other is "fail if I can't use cuDNN because of my own configuration of the RNN layer". The difference is subtle, but I think it is meaningful. Perhaps what I mean would be clearer in code.
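A minimal sketch of how the two intentions could be told apart in code, assuming (hypothetically) that distinct exception subclasses were raised for configuration failures and environment failures. None of these class names exist in Keras; `strict` and `config_only` stand in for user code expressing each intention.

```python
class CudnnNotUsedError(RuntimeError):
    """Hypothetical base class: cuDNN was requested but cannot be used."""


class CudnnConfigError(CudnnNotUsedError):
    """The layer's own arguments rule out cuDNN; fixable in model code."""


class CudnnUnavailableError(CudnnNotUsedError):
    """The backend, hardware, or installation cannot run cuDNN."""


def check_cudnn(config_issues, env_issues):
    """Raise the most specific error; configuration problems take priority."""
    if config_issues:
        raise CudnnConfigError("; ".join(config_issues))
    if env_issues:
        raise CudnnUnavailableError("; ".join(env_issues))


def strict(config_issues, env_issues):
    """Intention 1: fail if cuDNN can't be used, regardless of cause."""
    check_cudnn(config_issues, env_issues)
    return True


def config_only(config_issues, env_issues):
    """Intention 2: fail only if my RNN configuration is the reason."""
    try:
        check_cudnn(config_issues, env_issues)
    except CudnnUnavailableError:
        return False  # environment limitation: use the fallback kernel
    return True
```

Because both subclasses share a base, callers who do not care about the distinction can still catch `CudnnNotUsedError` and get the current "fail regardless of cause" behavior.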


Regardless, happy to contribute error message improvements.

fchollet commented 2 weeks ago

> Regardless, happy to contribute error message improvements.

Sure, please open a PR if you're able. We probably need to make two backend functions, is_config_cudnn_compatible and is_cudnn_available -- or maybe pass appropriate info in the NotImplementedError exception message.
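A sketch of the second option, assuming the backend signals an unusable cuDNN path by raising `NotImplementedError` and the layer then decides whether to fall back or re-raise. Subclassing `NotImplementedError` would keep any existing `except NotImplementedError` fallback working while attaching the reasons. `CudnnNotImplementedError`, `backend_cudnn_lstm`, and `layer_call` are hypothetical stand-ins, not Keras functions.

```python
class CudnnNotImplementedError(NotImplementedError):
    """Hypothetical NotImplementedError that records why cuDNN was skipped."""

    def __init__(self, config_reasons=(), env_reasons=()):
        self.config_reasons = list(config_reasons)
        self.env_reasons = list(env_reasons)
        parts = [f"configuration: {r}" for r in self.config_reasons]
        parts += [f"environment: {r}" for r in self.env_reasons]
        super().__init__("cuDNN kernel not used (" + "; ".join(parts) + ")")


def backend_cudnn_lstm(config_reasons, env_reasons):
    """Stand-in for a backend kernel that refuses to run without cuDNN."""
    if config_reasons or env_reasons:
        raise CudnnNotImplementedError(config_reasons, env_reasons)
    return "cudnn result"


def layer_call(use_cudnn, config_reasons, env_reasons):
    """Stand-in for the layer: try cuDNN, then fall back or re-raise."""
    try:
        return backend_cudnn_lstm(config_reasons, env_reasons)
    except NotImplementedError as e:
        if use_cudnn is True:
            # Surface the backend's reasons instead of a context-free error.
            raise ValueError(f"use_cudnn=True was passed, but {e}") from e
        return "generic result"
```

Keeping the reasons split into `config_reasons` and `env_reasons` on the exception would also let the layer support the "fail only on my configuration" intention without a second round of checks.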