Closed · qtomcatq closed 2 months ago

I've recently started using the Equinox/Diffrax frameworks, and my objective is to translate my PyTorch code to JAX for higher performance.

Often it's important to access all hidden states of recurrent neural networks (GRU, LSTM), which is quite straightforward to do in PyTorch; in Equinox, however, the implementation of recurrent neural networks only allows this through the lax.scan function. I've tried to access all hidden states in the RNN (from the example file train_rnn.ipynb) and made a version of the RNN (RNN2) that outputs a result based on a linear combination of all hidden states.

When I train RNN2, the weights of self.linear are not updated at all, and the final accuracy is 0.5. This is quite confusing, because the PyTorch version of similar code works without any problem.

I would appreciate any help with this problem.
I think it's just a problem with the shapes. You have batches of vectors that you're applying Linear layers to directly (they work on single vectors), so with a little vmap I get 100% accuracy:
```python
import equinox as eqx
import jax
import jax.numpy as jnp


class RNN2(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    fc: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, key):
        ckey, lkey, pkey = jax.random.split(key, 3)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        time_dim = 16
        self.linear = eqx.nn.Linear(time_dim, 1, key=lkey)
        self.fc = eqx.nn.Linear(hidden_size, 1, key=pkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, ys):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            h = self.cell(inp, carry)
            return h, h

        # fhidden stacks the hidden state from every step: (time_dim, hidden_size).
        out, fhidden = jax.lax.scan(f, hidden, ys)
        print(fhidden.shape)
        # eqx.nn.Linear expects a single vector, so vmap it over the leading axis.
        fhh = jax.vmap(self.linear)(fhidden).squeeze(axis=-1)
        print(fhh.shape)
        # fhh = jax.nn.softmax(fhh)
        return jax.nn.sigmoid(self.fc(fhh) + self.bias)
```
Thank you, @lockwo. I think the problem was the use of fhh[1] instead of squeeze; even though the values are the same (at least the JAX debugger says so), the behavior is completely different. If this line:

```python
return jax.nn.sigmoid(self.fc(fhh[1]) + self.bias)
```

is changed to this:

```python
return jax.nn.sigmoid(self.fc(fhh.squeeze(axis=0)) + self.bias)
```

it works fine.
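A side note that may explain the matching values: JAX clamps out-of-bounds indices on array reads instead of raising an error, so if fhh has a leading axis of length 1 (which squeeze(axis=0) requires), fhh[1] silently returns the same row as fhh[0]. A minimal illustration:

```python
import jax.numpy as jnp

x = jnp.arange(3.0).reshape(1, 3)  # leading axis of length 1, as squeeze(axis=0) requires

print(x.squeeze(axis=0))  # [0. 1. 2.] -- shape (3,)
print(x[1])               # [0. 1. 2.] -- index 1 is out of bounds and clamped to 0
```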
The code you suggested also works under certain conditions; however, jax.vmap(self.linear)(fhidden).squeeze(axis=-1) doesn't act on the temporal dimension but on the hidden dimension. If hidden_size != time_dim, it results in an error.
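To make that concrete, here is a minimal standalone sketch (my own illustration, with the two sizes deliberately unequal): transposing fhidden first makes vmap run over the hidden axis, so Linear(time_dim, 1) actually mixes over the temporal dimension.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

time_dim, hidden_size = 16, 32  # unequal on purpose to expose the mismatch
linear = eqx.nn.Linear(time_dim, 1, key=jax.random.PRNGKey(0))

# Stand-in for the stacked hidden states from lax.scan: (time_dim, hidden_size).
fhidden = jnp.ones((time_dim, hidden_size))

# jax.vmap(linear)(fhidden) would fail here: each mapped call receives a
# length-hidden_size vector, but linear expects length time_dim.
# After the transpose, vmap maps over the hidden axis and each call receives
# a length-time_dim vector, i.e. the layer combines all hidden states in time.
fhh = jax.vmap(linear)(fhidden.T).squeeze(axis=-1)
print(fhh.shape)  # (hidden_size,) -- the right input size for fc = Linear(hidden_size, 1)
```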