patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Access to all hidden states in recurrent neural networks #855

Closed: qtomcatq closed this issue 2 months ago

qtomcatq commented 2 months ago

I've recently started using the Equinox and Diffrax frameworks, and my objective is to translate my PyTorch code to JAX for higher performance.

It is often important to access all hidden states of a recurrent neural network (GRU, LSTM). This is quite straightforward in PyTorch, but with Equinox's recurrent cells it can only be done through lax.scan. I tried to access all hidden states of the RNN in the example notebook train_rnn.ipynb and wrote a version (RNN2) whose output is based on a linear combination of all hidden states:

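import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr

# original RNN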
class RNN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jr.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        out, _ = lax.scan(f, hidden, input)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(self.linear(out) + self.bias)

# new RNN
class RNN2(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    fc: eqx.nn.Linear 
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, time_dim, key, **kwargs):
        super().__init__(**kwargs)
        ckey, lkey, pkey = jr.split(key,3)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(time_dim, 1,  key=lkey)
        self.fc = eqx.nn.Linear(hidden_size, 1, key=pkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, ys):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            h = self.cell(inp,carry)
            return h, h

        out, fhidden = jax.lax.scan(f, hidden, ys)        

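        # fhidden: (time_dim, hidden_size) -> fhh: (1, hidden_size)
        fhh = self.linear(fhidden)
        #fhh = jax.nn.softmax(fhh)
        # NOTE: fhh has a single row, so index 1 is out of bounds on axis 0
        return jax.nn.sigmoid(self.fc(fhh[1])+self.bias)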

When I train RNN2, the weights of self.linear are not updated at all and the final accuracy is 0.5. This is confusing, because a PyTorch version of similar code works without any problem.

I would appreciate any help with this problem.
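
For reference, here is a minimal sketch of the lax.scan pattern that RNN2 relies on to collect every hidden state. It is only an illustration: toy_cell is a hypothetical stand-in for eqx.nn.GRUCell, and the sizes (16 time steps, 2 features, hidden size 8) are assumed values, not taken from the notebook.

```python
import jax.numpy as jnp
from jax import lax


def toy_cell(x, h):
    # Hypothetical stand-in for eqx.nn.GRUCell: (input, hidden) -> new hidden.
    return jnp.tanh(h + jnp.sum(x))


def run(xs, hidden_size):
    h0 = jnp.zeros((hidden_size,))

    def step(carry, x):
        h = toy_cell(x, carry)
        # First return value is the carry for the next step;
        # second is the per-step output that scan stacks along axis 0.
        return h, h

    final_h, all_h = lax.scan(step, h0, xs)
    return final_h, all_h


xs = jnp.ones((16, 2))  # (time, features), assumed sizes
final_h, all_h = run(xs, hidden_size=8)
print(final_h.shape)  # (8,)
print(all_h.shape)    # (16, 8)
```

Returning the new hidden state in both positions is what makes scan hand back the whole trajectory in addition to the final state, so the last row of all_h matches final_h.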

lockwo commented 2 months ago

I think it's just a problem with the shapes. You are applying Linear layers, which operate on single vectors, directly to a batch of vectors. With a little vmap, I get 100% accuracy:

class RNN2(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    fc: eqx.nn.Linear 
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, key):
        ckey, lkey, pkey = jax.random.split(key,3)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        time_dim = 16
        self.linear = eqx.nn.Linear(time_dim, 1,  key=lkey)
        self.fc = eqx.nn.Linear(hidden_size, 1, key=pkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, ys):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            h = self.cell(inp,carry)
            return h, h

        out, fhidden = jax.lax.scan(f, hidden, ys)        

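        print(fhidden.shape)  # (time_dim, hidden_size)
        # vmap maps over axis 0 (time), so self.linear sees one hidden-state vector per step
        fhh = jax.vmap(self.linear)(fhidden).squeeze(axis=-1)
        print(fhh.shape)  # (time_dim,)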
        #fhh = jax.nn.softmax(fhh)
        return jax.nn.sigmoid(self.fc(fhh)+self.bias)

qtomcatq commented 2 months ago

Thank you, @lockwo. I think the problem was using fhh[1] instead of squeeze. Even though the values are the same (at least according to the JAX debugger), the behavior is completely different. If this line:

return jax.nn.sigmoid(self.fc(fhh[1])+self.bias)

is changed to this:

return jax.nn.sigmoid(self.fc(fhh.squeeze(axis=0))+self.bias)

it works fine.
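
A small sketch of why fhh[1] and fhh.squeeze(axis=0) can show the same values: JAX does not raise on out-of-bounds indexing. For retrieval it clamps the index to the last valid position. The array below is an assumed stand-in for fhh with shape (1, hidden).

```python
import jax.numpy as jnp

# Stand-in for fhh: one row, like the (1, hidden_size) output discussed above.
fhh = jnp.arange(4.0).reshape(1, 4)

row_oob = fhh[1]              # index 1 is out of bounds on an axis of length 1
row_ok = fhh.squeeze(axis=0)  # the in-bounds way to drop the leading axis

print(row_oob.shape, row_ok.shape)
print(jnp.array_equal(row_oob, row_ok))

# The same clamping on a 1-D array, as in the JAX "sharp bits" documentation.
print(jnp.arange(10)[11])
```

The forward values agree, but JAX's documentation for its default "promise in bounds" indexing mode warns that gradients will not be correct when indices are out of bounds. That would be consistent with the frozen self.linear weights reported above, although the thread itself does not confirm this as the cause.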

The code you suggested also works under certain conditions. However, jax.vmap(self.linear)(fhidden).squeeze(axis=-1) acts on the hidden dimension rather than the temporal dimension, so it raises an error if hidden_size != time_dim.