Open layssi opened 1 week ago
Hi @layssi
I tested the code from this issue on Colab (CPU, GPU, and TPU v2) with JAX 0.4.26 and Flax 0.8.4, and on a MacBook CPU with JAX 0.4.30 and Flax 0.8.5. It ran without errors in every case, so I could not reproduce the error you reported.
Please find the gist for reference.
Could you please verify whether the issue persists with the latest versions of JAX and Flax?
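To compare environments, it may help to include the installed package versions in your reply. The snippet below is only a minimal sketch: the helper name `installed_version` is made up for illustration, and it uses just the standard library's `importlib.metadata`, so it runs even when one of the packages is missing.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Print the versions that are relevant to this issue.
for name in ("jax", "jaxlib", "flax"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Running this in the same interpreter (or notebook kernel) that raises the error should show whether it is using older releases than the ones tested above.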
Thank you.
import flax.linen as nn
import jax, jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (2, 3))
layer = nn.LSTMCell(features=4)
carry = layer.initialize_carry(jax.random.key(1), x.shape)
variables = layer.init(jax.random.key(2), carry, x)
new_carry, out = layer.apply(variables, carry, x)
Running the code above gives the error below. The code comes from the documentation.
flax.errors.AssignSubModuleError: Submodule LSTMCell must be defined in `setup()` or in a method wrapped in `@compact` (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.AssignSubModuleError)