mergian opened this issue 1 year ago (status: Open).
@mergian, thanks for your suggestion. Please feel free to raise a PR for your idea. Thanks!
👍🏻 Will start working on it now.
Just saw that `keras.layers.RNN.reset_states` has a `states` argument (https://github.com/keras-team/keras/blob/master/keras/layers/rnn/base_rnn.py#L846), which can be used to set the states, so I think `set_states` will not be needed. What's your opinion on it?
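As a rough illustration of that point, here is a minimal sketch, assuming Keras 2.x semantics where `RNN.reset_states(states=...)` assigns the given values, of a hypothetical helper `set_rnn_states` that sets a whole model's RNN states without a dedicated `set_states` method. It expects one entry per stateful RNN, in `model.layers` order, and expands `Bidirectional` wrappers into their `forward_layer` and `backward_layer`, since the wrapper's own `reset_states()` takes no `states` argument. Nested models are not searched.

```python
def iter_stateful_rnns(model):
    """Yield the stateful RNN layers of `model` (top level only).

    Bidirectional wrappers are expanded into their forward and backward
    RNN layers, because those are the layers that accept `states=`.
    """
    for layer in model.layers:
        if hasattr(layer, "forward_layer") and hasattr(layer, "backward_layer"):
            candidates = [layer.forward_layer, layer.backward_layer]
        else:
            candidates = [layer]
        for rnn in candidates:
            if getattr(rnn, "stateful", False) and hasattr(rnn, "reset_states"):
                yield rnn


def set_rnn_states(model, states_per_layer):
    """Hypothetical helper: assign RNN states via reset_states(states=...).

    `states_per_layer` has one entry per stateful RNN, e.g. `[h, c]` arrays
    of shape (batch_size, units) for an LSTM.
    """
    rnns = list(iter_stateful_rnns(model))
    if len(rnns) != len(states_per_layer):
        raise ValueError(
            f"expected {len(rnns)} state entries, got {len(states_per_layer)}"
        )
    for rnn, states in zip(rnns, states_per_layer):
        rnn.reset_states(states=states)
```

With a real model containing a single stateful LSTM, `set_rnn_states(model, [[h0, c0]])` would assign both states; passing `None` as an entry resets that layer to zeros instead.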
Hi @mergian,
I went through the `reset_states()` method in the Bidirectional RNN API documentation, and its implementation there is as follows:
```python
# Bidirectional.reset_states(): resets the states of both wrapped RNN layers
def reset_states(self):
    self.forward_layer.reset_states()
    self.backward_layer.reset_states()
```
Wouldn't this serve your purpose? Please let us know.
I think I wasn't clear with the question. In my first post I proposed adding `def set_states(self, states):` to the `keras.Model` class. But then I saw that the RNN layers already have a `def reset_states(self, states=None):` method (see here), so I think it would be more consistent to add the following functions to `keras.Model`:
- `def reset_states(self, states=None)`: identical to `keras.layers.RNN.reset_states(self, states=None)`
- `def get_states(self)`: similar to `keras.Model.get_weights(self)`, but for the RNN states
- `@property def states(self)`: similar to `keras.Model.weights`, but for the RNN states

If that is OK, then I will prepare the pull request and we can discuss details later if necessary.
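To make the proposed `get_states` and `states` concrete, here is a minimal sketch written as free functions rather than `keras.Model` methods. The names `model_states` and `get_model_states` are hypothetical, and the sketch assumes Keras 2.x stateful RNNs whose `states` attribute holds one variable per state once the model is built. Mirroring `get_weights()`, the getter returns NumPy copies rather than the live variables.

```python
import numpy as np


def _stateful_rnns(model):
    """Yield stateful RNN layers, expanding Bidirectional wrappers (top level only)."""
    for layer in model.layers:
        if hasattr(layer, "forward_layer") and hasattr(layer, "backward_layer"):
            candidates = [layer.forward_layer, layer.backward_layer]
        else:
            candidates = [layer]
        for rnn in candidates:
            if getattr(rnn, "stateful", False) and hasattr(rnn, "states"):
                yield rnn


def model_states(model):
    """Hypothetical analogue of `Model.weights`: the live state objects."""
    return [state for rnn in _stateful_rnns(model) for state in rnn.states]


def get_model_states(model):
    """Hypothetical analogue of `Model.get_weights()`: NumPy copies of the states."""
    return [np.array(state) for state in model_states(model)]
```

Returning copies keeps the getter's result stable even after the layers continue updating their states, which is the same contract `get_weights()` offers for weights.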
@mergian, I just wanted to give some context in case you are using the Bidirectional wrapper in your model. You can proceed with the PR, and we can discuss details later.
Removing the review-required label as we're taking a look at this in keras-team/keras#17468.
This issue can either be closed (fixed) or closed (won't fix) based on that PR.
System information.
TensorFlow version (you are using): 2.11.0
Are you willing to contribute it (Yes/No): Yes
Describe the feature and the current behavior/state.
Stateful RNN layers have the method `layer.reset_states()`, and the states themselves can be fetched through `layer.states`. However, when a model consists of many RNN layers mixed with other layers, it becomes tedious to loop through all layers and reset them manually: the current approach is to iterate over the model's layers and call `reset_states()` on each stateful RNN. This becomes really cumbersome when you use bidirectional RNNs, because then you also need to check whether the layer has `l.forward_layer` and `l.backward_layer` and reset the states in them as well.
Therefore, my proposal is to add `reset_states`, `get_states` and `set_states` to `keras.Model`. The last two work similarly to `get_weights()` and `set_weights()`. A possible implementation could be:
Will this change the current API? How?
Yes, it adds the methods `reset_states`, `get_states` and `set_states` to the class `keras.Model`, so people don't need to loop through the Keras data structures.
Who will benefit from this feature?
All people who use stateful RNN layers ;)
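As a hedged sketch of the manual loop described above (not the author's original code), a helper along these lines could reset every stateful RNN in a model. The name `reset_rnn_states` is hypothetical, and it assumes Keras 2.x attribute names (`layers`, `stateful`, `forward_layer`, `backward_layer`, `reset_states`).

```python
def reset_rnn_states(model):
    """Reset every stateful RNN in `model`.

    Descends into Bidirectional wrappers (via forward_layer/backward_layer)
    and into nested models (anything with its own `layers` list).
    """
    for layer in model.layers:
        if hasattr(layer, "forward_layer") and hasattr(layer, "backward_layer"):
            targets = [layer.forward_layer, layer.backward_layer]
        elif hasattr(layer, "layers"):
            # Nested Model/Sequential: handle its layers recursively.
            reset_rnn_states(layer)
            continue
        else:
            targets = [layer]
        for rnn in targets:
            if getattr(rnn, "stateful", False) and hasattr(rnn, "reset_states"):
                rnn.reset_states()
```

Calling `reset_rnn_states(model)` between independent sequences is exactly the kind of boilerplate a built-in `Model.reset_states` would absorb.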