cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Module init fails at wrong key type #9

Closed: kimbochen closed this issue 2 years ago

kimbochen commented 2 years ago

Hi, thanks for the great work! I am trying to learn how to use JAX and treex, so I followed the tutorial.

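import jax
import jax.numpy as jnp
import treex as tx

class Linear(tx.Module):
    w: tx.Parameter[tx.Initializer, jnp.ndarray]
    b: tx.Parameter[jnp.ndarray]

    def __init__(self, din, dout):
        super().__init__()
        # w is created lazily: the Initializer receives a PRNG key during .init()
        self.w = tx.Initializer(
            lambda key: jax.random.uniform(key, shape=(din, dout)))
        self.b = jnp.zeros(shape=(dout,))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

# .init(42) passes 42 as the seed for the initializers; this is the call that fails
linear = Linear(3, 5).init(42)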

However, I always get this assertion error.

---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

<ipython-input-7-4b5c9a5c519d> in <module>()
     12         return jnp.dot(x, self.w) + self.b
     13 
---> 14 linear = Linear(3, 5).init(42)

[... 3 frames omitted ...]

/usr/local/lib/python3.7/dist-packages/treex/module.py in next_key()
     57         def next_key() -> jnp.ndarray:
     58             nonlocal key
---> 59             assert isinstance(key, jnp.ndarray)
     60             next_key, key = jax.random.split(key)
     61             return next_key

AssertionError: 

After digging into the code, I found that jax.random.split(key) seems to return keys of type numpy.ndarray, which do not pass the isinstance(key, jnp.ndarray) assertion. Replacing jnp.ndarray with np.ndarray still causes problems, because key starts out as a jaxlib.xla_extension.DeviceArray. I would love to make a PR, but I am not sure how to fix this. Here's a Colab notebook that replicates the issue.
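An isinstance check pinned to one concrete array class is fragile when JAX changes which class its functions return, as observed above. As a hedged sketch only (the thread does not show how Treex actually fixed it), a key-splitting helper in the same style as `next_key` can normalise its input with `jnp.asarray` instead of asserting on the type. `make_key_stream` is a hypothetical name, and the sketch assumes raw `uint32` keys such as those returned by `jax.random.PRNGKey`.

```python
import jax
import jax.numpy as jnp


def make_key_stream(key):
    """Hypothetical helper: return a function that yields a fresh PRNG key per call.

    Accepts an int seed or an existing raw key array (JAX or NumPy) and does not
    assert on a specific array class (jnp.ndarray vs np.ndarray vs DeviceArray).
    """
    if isinstance(key, int):
        # Turn an integer seed into a raw uint32 key.
        key = jax.random.PRNGKey(key)
    # Normalise NumPy or JAX inputs into a JAX array instead of type-checking them.
    key = jnp.asarray(key)

    def next_key():
        nonlocal key
        # split returns two keys: hand one to the caller, carry the other forward.
        subkey, key = jax.random.split(key)
        return subkey

    return next_key


# Usage: draw keys for parameter initialisation, as Linear's Initializer would.
next_key = make_key_stream(42)
w = jax.random.uniform(next_key(), shape=(3, 5))
```

Normalising with `jnp.asarray` means a key that arrives as a NumPy array (for example after `jax.device_get`) is accepted just like one that is already on device.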

lkhphuc commented 2 years ago

I opened PR #10, which at least makes the example run. Take a look if you're interested.

cgarciae commented 2 years ago

Hey @lkhphuc, sorry for the late reply; I somehow missed this issue in my inbox.

The issue is related to a recent change in JAX (see google/jax#8017); this will be fixed in the next version of Treex. I think the easiest fix for you right now is to roll back to jax==0.2.20, as this issue also affects Flax, which Treex uses.

cgarciae commented 2 years ago

@lkhphuc, I believe this should be fixed in treex==0.5.0.
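To tell whether an environment falls into the broken combination before choosing a workaround (pinning jax==0.2.20 or upgrading to treex==0.5.0), a small version check can help. This is a minimal sketch: `parse_version`, `is_affected` and `installed` are hypothetical helpers, and the affected range (JAX newer than 0.2.20 together with Treex older than 0.5.0) is an assumption inferred from this thread, not a documented compatibility matrix. `importlib.metadata` requires Python 3.8+, so the Python 3.7 runtime in the traceback would need the `importlib_metadata` backport instead.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_version(v: str) -> tuple:
    """Turn "0.2.20" or "0.5.0rc1" into a comparable tuple like (0, 2, 20)."""
    parts = []
    for piece in v.split(".")[:3]:
        match = re.match(r"\d+", piece)  # leading digits only, e.g. "0rc1" -> "0"
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def is_affected(jax_version: str, treex_version: str) -> bool:
    """ASSUMED affected range: jax newer than 0.2.20 with treex older than 0.5.0."""
    return (parse_version(jax_version) > (0, 2, 20)
            and parse_version(treex_version) < (0, 5, 0))


def installed(name: str):
    """Return the installed version string of a distribution, or None."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_v, treex_v = installed("jax"), installed("treex")
    if jax_v and treex_v and is_affected(jax_v, treex_v):
        print(f"jax {jax_v} with treex {treex_v}: upgrade treex (>= 0.5.0) "
              "or pin jax==0.2.20")
```

Upgrading Treex is the less invasive option of the two, since pinning JAX back also pins every other library in the environment that depends on it.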

kimbochen commented 2 years ago

Thanks for the replies, the problem is fixed.