anh-tong closed this issue 8 months ago.
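After inspecting the full error trace (you need to set `JAX_TRACEBACK_FILTERING=off` to see it), it seems the problem is that `f1`, `f2`, and `f3` are `equinox.Module` instances, not plain functions. `jax.lax.switch` passes its branches through a cached helper, which hashes them, and `equinox.Module.__hash__` hashes the module's array leaves, which fails: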
```
  File "/opt/conda/lib/python3.10/site-packages/jax/_src/lax/control_flow/conditionals.py", line 139, in switch
    jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
  File "/opt/conda/lib/python3.10/site-packages/jax/_src/util.py", line 263, in wrapper
    return cached(config.config._trace_context(), *args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/equinox/_module.py", line 839, in __hash__
    return hash(tuple(jtu.tree_leaves(self)))
TypeError: unhashable type: 'ArrayImpl'
```
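The failure can be reproduced without Equinox or Haliax. In the sketch below, `ToyLinear` is a hypothetical stand-in for an `equinox.Module`: it copies the `__hash__` strategy shown in the trace (hash the tuple of array leaves), so hashing it raises the same kind of `TypeError`, while a plain function wrapping it hashes fine by identity.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyLinear:
    """Hypothetical stand-in for an equinox.Module holding an array parameter."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return x @ self.weight

    def __hash__(self):
        # Same strategy as the equinox __hash__ in the trace: hash the array leaves.
        return hash(tuple(jax.tree_util.tree_leaves(self)))

    def tree_flatten(self):
        return (self.weight,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


module = ToyLinear(jnp.ones((3, 2)))


def module_fn(x):
    # A plain function is hashed by identity, so it is safe to use as a branch.
    return module(x)


try:
    hash(module)
    module_hash_error = None
except TypeError as err:
    # JAX arrays are unhashable; the type name in the message depends on the JAX version.
    module_hash_error = str(err)
```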
A quick fix is to wrap each module in a plain Python function:
```python
import jax
import haliax as hax

Batch = hax.Axis("batch", 2)
In = hax.Axis("in", 10)
Out = hax.Axis("out", 5)

f1 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(0))
f2 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(1))
f3 = hax.nn.Linear.init(Out=Out, In=In, key=jax.random.PRNGKey(2))

x = hax.random.normal(key=jax.random.PRNGKey(3), shape=(Batch, In))


# Plain functions are hashed by identity, so jax.lax.switch can use them as
# branches; the Linear modules themselves are not hashable (see the trace).
def f1_fn(x):
    return f1(x)


def f2_fn(x):
    return f2(x)


def f3_fn(x):
    return f3(x)


index = 0  # either 0, 1, or 2
out = jax.lax.switch(index, [f1_fn, f2_fn, f3_fn], x)
```
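With many modules, the same wrapping can be done in a loop. The helper `switch_callables` below is a hypothetical name (not part of Haliax, Equinox, or JAX), and `Scale` is a hypothetical toy module whose hash fails the same way as in the trace. The helper builds each wrapper through a small factory function, because a `lambda x: m(x)` written directly inside a loop would late-bind `m` and make every branch call the last module.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Scale:
    """Hypothetical toy module: multiplies its input by an array parameter."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor

    def __hash__(self):
        # Like equinox.Module in the trace: hashing array leaves raises TypeError.
        return hash(tuple(jax.tree_util.tree_leaves(self)))

    def tree_flatten(self):
        return (self.factor,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def switch_callables(index, callables, *operands):
    """Hypothetical helper: wrap each callable in a fresh, hashable function."""

    def make_branch(fn):
        # Each call creates a new function that captures its own `fn`,
        # avoiding the late-binding pitfall of a lambda defined in the loop.
        def branch(*args):
            return fn(*args)

        return branch

    branches = [make_branch(fn) for fn in callables]
    return jax.lax.switch(index, branches, *operands)


modules = [Scale(jnp.asarray(1.0)), Scale(jnp.asarray(2.0)), Scale(jnp.asarray(3.0))]
x = jnp.arange(3.0)

out = switch_callables(1, modules, x)  # calls modules[1]: [0., 2., 4.]
```

The index may also be a traced value, for example when `switch_callables` is called inside `jax.jit`; the modules are then closed over as constants.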
Again, thank you for the library.
Glad you worked it out! A `filter_switch` or something similar seems like it could be a good addition to Equinox, if you wanted to open an issue over there. (I'd be willing to implement it if Patrick isn't opposed.)
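To make the suggestion concrete, here is one possible shape for such a helper. The name `filter_switch` is taken from the comment above, but the implementation is purely a hypothetical sketch in plain JAX, not an Equinox API. It splits each branch pytree into array leaves and everything else, passes the arrays to `jax.lax.switch` as ordinary operands, and rebuilds the branch inside a plain (hashable) function. `Affine` is a hypothetical toy module whose non-array leaf (an activation function) could not be passed to `switch` directly.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.tree_util.register_pytree_node_class
class Affine:
    """Hypothetical toy module with array parameters and a function leaf."""

    def __init__(self, weight, bias, activation):
        self.weight, self.bias, self.activation = weight, bias, activation

    def __call__(self, x):
        return self.activation(x @ self.weight + self.bias)

    def __hash__(self):
        # Mimics the equinox __hash__ from the trace: fails on array leaves.
        return hash(tuple(jax.tree_util.tree_leaves(self)))

    def tree_flatten(self):
        # The activation is a leaf too, but it is not an array.
        return (self.weight, self.bias, self.activation), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def _is_array(leaf):
    return isinstance(leaf, (jax.Array, np.ndarray))


def filter_switch(index, branches, *operands):
    """Hypothetical sketch: lax.switch over pytree callables with array leaves."""
    arrays, rebuilders = [], []
    for branch in branches:
        leaves, treedef = jax.tree_util.tree_flatten(branch)
        # Array leaves become switch operands; other leaves stay Python objects.
        arrays.append([leaf for leaf in leaves if _is_array(leaf)])
        # None is never a pytree leaf in JAX, so it is a safe placeholder.
        statics = [None if _is_array(leaf) else leaf for leaf in leaves]
        rebuilders.append((treedef, statics))

    def make_branch(i):
        treedef, statics = rebuilders[i]

        def branch_fn(all_arrays, *ops):
            it = iter(all_arrays[i])
            leaves = [next(it) if s is None else s for s in statics]
            return jax.tree_util.tree_unflatten(treedef, leaves)(*ops)

        return branch_fn

    fns = [make_branch(i) for i in range(len(branches))]
    return jax.lax.switch(index, fns, arrays, *operands)


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
branches = [
    Affine(jax.random.normal(key1, (4, 2)), jnp.zeros(2), jax.nn.relu),
    Affine(jax.random.normal(key2, (4, 2)), jnp.ones(2), jnp.tanh),
]
x = jnp.ones((3, 4))
out = filter_switch(1, branches, x)
```

Passing the parameters as operands, rather than capturing them in closures, keeps them as explicit inputs to the `switch`; whether that is worth the extra machinery over the simple wrapping above is a design choice.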
On Fri, Oct 27, 2023 at 6:17 AM Anh Tong @.***> wrote:
Closed #46 https://github.com/stanford-crfm/haliax/issues/46 as completed.
Hi Haliax team,
I'm starting to learn Haliax while working on one of my projects in Levanter. I'm wondering how to do something like `jax.lax.switch` in Haliax. The code above returns the error `TypeError: unhashable type: 'ArrayImpl'`. Do you have any suggestions? Thank you!
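For reference, `jax.lax.switch` itself accepts any pytree as its operands (Haliax's `NamedArray` is a pytree), and the index may be a traced value, for example inside `jax.jit`; out-of-range indices are clamped to the valid range. What it cannot handle is branches that fail to hash. A minimal plain-JAX sketch with a dict operand and plain functions as branches:

```python
import jax
import jax.numpy as jnp


def double(p):
    return {"a": p["a"] * 2, "b": p["b"] * 2}


def negate(p):
    return {"a": -p["a"], "b": -p["b"]}


def identity(p):
    return p


@jax.jit
def apply(index, params):
    # Plain functions as branches: hashable, so lax.switch can cache them.
    # All branches must return the same pytree structure, shapes, and dtypes.
    return jax.lax.switch(index, [double, negate, identity], params)


params = {"a": jnp.arange(3.0), "b": jnp.ones(2)}
out0 = apply(0, params)
out_clamped = apply(7, params)  # clamped to the last branch (identity)
```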