keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

Missing methods to easily access reset_state and states within keras.Model #61

Open mergian opened 1 year ago

mergian commented 1 year ago

System information.

TensorFlow version (you are using): 2.11.0

Are you willing to contribute it (Yes/No): yes

Describe the feature and the current behavior/state.

Stateful RNN layers have the method layer.reset_state(), and the states themselves can be fetched through layer.states. However, when a model consists of many RNN layers mixed with other layers, it becomes tedious to loop through all layers and reset them manually. The current approach looks something like this:

for l in model.layers:
    if hasattr(l, 'reset_state'):
        l.reset_state()

This becomes really cumbersome when you use bidirectional RNNs, because you then also need to check whether the layer has l.forward_layer and l.backward_layer and reset the states in them as well.

Therefore, my proposal is to add reset_state, get_states and set_states to keras.Model. The last two would work similarly to get_weights() and set_weights(). A possible implementation could be:

def reset_state(self):
    def reset_state(l):
        if hasattr(l, 'reset_state'):
            l.reset_state()
        if hasattr(l, 'forward_layer'):
            reset_state(l.forward_layer)
        if hasattr(l, 'backward_layer'):
            reset_state(l.backward_layer)

    for l in self.layers:
        reset_state(l)

def get_states(self):
    states = []
    def get_states(l):
        if hasattr(l, 'states'):
            states.extend(l.states)
        if hasattr(l, 'forward_layer'):
            get_states(l.forward_layer)
        if hasattr(l, 'backward_layer'):
            get_states(l.backward_layer)

    for l in self.layers:
        get_states(l)

    return states

def set_states(self, states):
    it = iter(states)
    def set_states(l):
        if hasattr(l, 'states'):
            for s in l.states:
                s.assign(next(it))
        if hasattr(l, 'forward_layer'):
            set_states(l.forward_layer)
        if hasattr(l, 'backward_layer'):
            set_states(l.backward_layer)

    for l in self.layers:
        set_states(l)
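
The three methods above repeat the same recursive walk over forward_layer and backward_layer. A minimal sketch of how that walk could be factored into one helper follows. FakeRNN, FakeBidirectional, FakeDense and iter_stateful_layers are hypothetical stand-ins invented for illustration, not Keras classes or functions, and plain floats stand in for state tensors.

```python
from typing import Iterator, List


class FakeRNN:
    """Hypothetical stand-in for a stateful RNN layer (not a Keras class)."""

    def __init__(self, states: List[float]):
        self.states = list(states)

    def reset_state(self) -> None:
        # Zero every state, mimicking a reset.
        self.states = [0.0 for _ in self.states]


class FakeBidirectional:
    """Hypothetical stand-in for a wrapper holding two RNN layers."""

    def __init__(self, forward_layer: FakeRNN, backward_layer: FakeRNN):
        self.forward_layer = forward_layer
        self.backward_layer = backward_layer


class FakeDense:
    """Hypothetical stand-in for a layer without any state."""


def iter_stateful_layers(layers) -> Iterator[object]:
    """Yield every layer exposing `states`, descending into wrappers."""
    for layer in layers:
        if hasattr(layer, "states"):
            yield layer
        for name in ("forward_layer", "backward_layer"):
            inner = getattr(layer, name, None)
            if inner is not None:
                yield from iter_stateful_layers([inner])


def reset_state(layers) -> None:
    """Reset every stateful layer found by the shared walk."""
    for layer in iter_stateful_layers(layers):
        layer.reset_state()


def get_states(layers) -> list:
    """Collect all states as one flat list, in traversal order."""
    return [s for layer in iter_stateful_layers(layers) for s in layer.states]


layers = [
    FakeRNN([1.0, 2.0]),
    FakeDense(),
    FakeBidirectional(FakeRNN([3.0]), FakeRNN([4.0])),
]
before = get_states(layers)
reset_state(layers)
after = get_states(layers)
```

With a single traversal helper, any further wrapper attribute only needs to be added in one place. Whether the real layers expose these attributes in exactly this way is an assumption taken from the proposal above.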

Will this change the current API? How?

Yes, it adds the methods reset_state, get_states and set_states to class keras.Model, so people don't need to loop through the Keras data structures.

Who will benefit from this feature?

Everyone who uses stateful RNN layers ;)

Contributing

SuryanarayanaY commented 1 year ago

@mergian, thanks for your suggestion. Please feel free to raise a PR for your idea. Thanks!

mergian commented 1 year ago

👍🏻 will start working on it now.

mergian commented 1 year ago

I just saw that keras.layers.RNN.reset_states has a states argument (https://github.com/keras-team/keras/blob/master/keras/layers/rnn/base_rnn.py#L846), which can be used to set the states, so I think set_states will not be needed. What's your opinion on it?
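
For illustration, folding set_states into an optional states argument could look roughly like the sketch below. FakeStatefulLayer and reset_model_states are hypothetical names made up for this example. The stand-in only mimics the shape of a reset_states(states=None) signature, with plain floats instead of tensors, and it is not the Keras implementation.

```python
from typing import List, Optional, Sequence


class FakeStatefulLayer:
    """Hypothetical stand-in for a stateful RNN layer (not a Keras class)."""

    def __init__(self, num_states: int):
        self.states: List[float] = [1.0] * num_states

    def reset_states(self, states: Optional[Sequence[float]] = None) -> None:
        if states is None:
            # No values given: fall back to zeros.
            self.states = [0.0] * len(self.states)
            return
        if len(states) != len(self.states):
            raise ValueError(
                "expected %d states, got %d" % (len(self.states), len(states))
            )
        self.states = list(states)


def reset_model_states(layers, states: Optional[Sequence[float]] = None) -> None:
    """Reset all stateful layers; if `states` is given, hand each layer its slice."""
    stateful = [layer for layer in layers if hasattr(layer, "reset_states")]
    if states is None:
        for layer in stateful:
            layer.reset_states()
        return
    expected = sum(len(layer.states) for layer in stateful)
    if len(states) != expected:
        raise ValueError("expected %d states, got %d" % (expected, len(states)))
    offset = 0
    for layer in stateful:
        count = len(layer.states)
        layer.reset_states(states[offset:offset + count])
        offset += count


def collect_states(layers) -> list:
    """Flatten the states of all stateful layers, in layer order."""
    return [s for layer in layers if hasattr(layer, "states") for s in layer.states]


# object() plays the role of a layer without state and is skipped.
layers = [FakeStatefulLayer(2), object(), FakeStatefulLayer(1)]
reset_model_states(layers, [5.0, 6.0, 7.0])
restored = collect_states(layers)
reset_model_states(layers)
zeroed = collect_states(layers)
```

One method then covers both "reset to zeros" and "restore saved values", which is why a separate set_states may be unnecessary.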

SuryanarayanaY commented 1 year ago

Hi @mergian, I went through the reset_states() method in the Bidirectional RNN API documentation; its implementation is given there as follows.

    def reset_states(self):
        self.forward_layer.reset_states()
        self.backward_layer.reset_states()

Wouldn't this serve your purpose? Please let us know.

mergian commented 1 year ago

I think I wasn't clear with my question. In my first post I proposed adding def set_states(self, states): to the keras.Model class. But then I saw that the RNN layers already have a def reset_states(self, states): method (see here), so I think it would be more consistent to add the following functions to keras.Model:

If that is OK, then I will prepare the pull request and we can discuss details later if necessary.

SuryanarayanaY commented 1 year ago

@mergian, I just wanted to give some context in case you are using the Bidirectional wrapper in the model. You can proceed with the PR, and we can discuss the details later.

ianstenbit commented 1 year ago

Removing the review-required label as we're taking a look at this in keras-team/keras#17468.

This issue can either be closed (fixed) or closed (won't fix) based on that PR.