google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

LSTM error #4032

Open layssi opened 1 week ago

layssi commented 1 week ago

```python
import flax.linen as nn
import jax, jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (2, 3))  # batch of 2 inputs, 3 features each
layer = nn.LSTMCell(features=4)
carry = layer.initialize_carry(jax.random.key(1), x.shape)  # (c, h) initial state
variables = layer.init(jax.random.key(2), carry, x)
new_carry, out = layer.apply(variables, carry, x)
```

Running this code gives the following error. The code is taken from the documentation:

flax.errors.AssignSubModuleError: Submodule LSTMCell must be defined in setup() or in a method wrapped in @compact (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.AssignSubModuleError)

rajasekharporeddy commented 1 day ago

Hi @layssi

I tested the code on Colab CPU, GPU, and TPU v2 with JAX 0.4.26 and Flax 0.8.4, and also on a MacBook CPU with JAX 0.4.30 and Flax 0.8.5. I could not reproduce the error; the code runs without issues.

Please find the gist for reference.

Could you please verify whether the issue persists with the latest versions of JAX and Flax?

Thank you.