stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

`jax.lax.switch` for Haliax #46

Closed anh-tong closed 8 months ago

anh-tong commented 8 months ago

Hi Haliax team,

I'm just starting to learn Haliax while working on one of my projects in Levanter. How can I write a switch like jax.lax.switch in Haliax?

import jax
import haliax as hax

Batch = hax.Axis("batch", 2)
In = hax.Axis("in", 10)
Out = hax.Axis("out", 5)

f1 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(0))
f2 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(1))
f3 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(2))

x = hax.random.normal(key=jax.random.PRNGKey(3), shape=(Batch, In))

index = 0 # either 0, 1, 2
jax.lax.switch(index, (f1, f2, f3), x)

The code above raises TypeError: unhashable type: 'ArrayImpl'. Do you have any suggestions? Thank you!

anh-tong commented 8 months ago

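After inspecting the full error trace below (this requires setting JAX_TRACEBACK_FILTERING=off), it seems the problem is that f1, f2, and f3 are equinox.Module instances, not plain functions. jax.lax.switch passes its branches through a cached helper, which hashes them, and the module's __hash__ hashes its array leaves, which are unhashable.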

  File "/opt/conda/lib/python3.10/site-packages/jax/_src/lax/control_flow/conditionals.py", line 139, in switch
    jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
  File "/opt/conda/lib/python3.10/site-packages/jax/_src/util.py", line 263, in wrapper
    return cached(config.config._trace_context(), *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/equinox/_module.py", line 839, in __hash__
    return hash(tuple(jtu.tree_leaves(self)))
TypeError: unhashable type: 'ArrayImpl'
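
The failure in the trace can be reproduced without Haliax or Equinox. The sketch below is only an illustration: `ModuleLike` is a hypothetical stand-in whose `__hash__` mimics the `hash(tuple(leaves))` line in the trace, and it is not Equinox's actual class. It shows that hashing an object by its array leaves fails, while a plain function wrapping the same object hashes by identity.

```python
import jax.numpy as jnp


class ModuleLike:
    """Hypothetical stand-in for a module that hashes its array leaves."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return self.weight @ x

    def __hash__(self):
        # Mirrors the idea in the trace: hash a tuple of the leaves.
        # JAX arrays are unhashable, so this raises TypeError.
        return hash((self.weight,))


module = ModuleLike(jnp.ones((5, 10)))

try:
    hash(module)
    hash_failed = False
except TypeError as err:
    hash_failed = True
    print("hashing the module failed:", err)


def module_fn(x):
    # A plain function closes over the module and hashes by identity.
    return module(x)


fn_hash = hash(module_fn)
```

Anything that needs to hash the branches, such as the cached helper in the trace, therefore accepts `module_fn` but not `module`.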

A quick fix is to wrap each module in a plain function:

import jax
import haliax as hax

Batch = hax.Axis("batch", 2)
In = hax.Axis("in", 10)
Out = hax.Axis("out", 5)

f1 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(0))
f2 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(1))
f3 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(2))

x = hax.random.normal(key=jax.random.PRNGKey(3), shape=(Batch, In))

def f1_fn(x):
    return f1(x)

def f2_fn(x):
    return f2(x)

def f3_fn(x):
    return f3(x)

index = 0 # either 0, 1, 2
jax.lax.switch(index, [f1_fn, f2_fn, f3_fn], x)
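
The three hand-written wrappers can also be generated with a small factory. This is a minimal sketch using plain JAX weight matrices in place of the hax.nn.Linear modules (an assumption made to keep the example self-contained); `make_branch` and `apply_selected` are hypothetical names. The factory avoids Python's late-binding pitfall that a lambda defined inside a loop would hit.

```python
import jax
import jax.numpy as jnp


def make_branch(weight):
    # Each call creates a fresh closure bound to its own weight.
    def branch(x):
        return weight @ x

    return branch


keys = jax.random.split(jax.random.PRNGKey(0), 3)
weights = [jax.random.normal(k, (5, 10)) for k in keys]
branches = [make_branch(w) for w in weights]


@jax.jit
def apply_selected(index, x):
    # index may be a traced integer here; switch picks the branch at run time.
    return jax.lax.switch(index, branches, x)


x = jnp.ones(10)
outputs = [apply_selected(i, x) for i in range(3)]
```

With the modules from the snippet above, the same idea would be a factory that takes a module and returns `lambda x: module(x)`.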

Again thank you for the library.

dlwh commented 8 months ago

Glad you worked it out! A filter_switch or something similar seems like it could be a good addition to Equinox, if you want to open an issue over there. (I'd be willing to implement it if Patrick isn't opposed.)

On Fri, Oct 27, 2023 at 6:17 AM Anh Tong @.***> wrote:

Closed #46 https://github.com/stanford-crfm/haliax/issues/46 as completed.
