Egiob opened this issue 6 months ago.
That's odd -- I've just tried running your code (with the same versions of each library) and don't see the same issue. Can you perhaps double-check in a new environment?
I see, my minimal reproduction was ambiguous; sorry about that. I figured out that it depends on the order of the @jaxtyped decorator. This fails:
import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped
@jaxtyped(typechecker=beartype.beartype)
@flax.struct.dataclass
class Data:
    a: Array
def f(x: Data) -> int:
    return 1
data = Data(a=jnp.ones(10, dtype=int))
jax.vmap(f)(data)
This doesn't fail:
import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped
@flax.struct.dataclass
@jaxtyped(typechecker=beartype.beartype)
class Data:
    a: Array
def f(x: Data) -> int:
    return 1
data = Data(a=jnp.ones(10, dtype=int))
jax.vmap(f)(data)
I think the error occurred for me because I was using the pytest hook, which, according to the docs, adds the jaxtyped decorator on top (i.e. outermost).
Tell me if you can reproduce this :smile: (I have beartype==0.17.2, but I don't think it matters.)
Ah, thank you!
It looks like this is a bug in Flax itself. Here's an MWE that doesn't use jaxtyping:
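import flax
import jax.tree_util as jtu
@flax.struct.dataclass
class A:
    x: int
    def __init__(self):
        # Set the field by hand (the dataclass is frozen) so that flattening succeeds.
        object.__setattr__(self, "x", 1)
leaves, treedef = jtu.tree_flatten(A())
jtu.tree_unflatten(treedef, leaves)  # TypeError: unflattening calls A(x=1), which __init__ rejects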
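It looks like the reason is that Flax's tree-unflattening rule rebuilds the instance by calling the type's __init__ method, which is a long-standing gotcha when using JAX: transformations such as jax.vmap sometimes unflatten a pytree with placeholder leaves rather than real arrays, so a custom __init__ (as in this MWE) or one that validates its arguments (as the one jaxtyped adds does) can break.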
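To make the gotcha concrete without Flax or jaxtyping, here is a minimal sketch. It relies on behaviour described in the JAX pytree documentation: jax.vmap may unflatten its inputs with non-array placeholder leaves while working out the axis structure. The classes `Checked` and `Bypassed`, and the `isinstance` check standing in for jaxtyped's beartype check, are hypothetical. The second registration shows the commonly recommended fix for pytrees you register yourself: build the instance with `object.__new__` in the unflatten rule so that `__init__` never runs.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class Checked:
    """Validates its field in __init__, standing in for a jaxtyped __init__."""

    def __init__(self, a):
        if not isinstance(a, jax.Array):
            raise TypeError(f"expected a jax.Array, got {type(a).__name__}")
        self.a = a


class Bypassed(Checked):
    """Same __init__, but its unflatten rule (registered below) never calls it."""


def _flatten(obj):
    # One child (the array field) and no static auxiliary data.
    return (obj.a,), None


def _unflatten_via_init(aux, children):
    # Rebuilds through __init__, so validation runs on whatever leaves JAX passes in.
    return Checked(*children)


def _unflatten_bypass_init(aux, children):
    # Creates the instance without running __init__ and sets the field directly.
    obj = object.__new__(Bypassed)
    obj.a = children[0]
    return obj


jtu.register_pytree_node(Checked, _flatten, _unflatten_via_init)
jtu.register_pytree_node(Bypassed, _flatten, _unflatten_bypass_init)


def f(x):
    return x.a.sum()


xs = jnp.ones((4, 3))

try:
    jax.vmap(f)(Checked(xs))
    checked_failed = False
except TypeError as err:
    # The validating __init__ saw a placeholder leaf instead of an array.
    checked_failed = True
    print("Checked:", err)

result = jax.vmap(f)(Bypassed(xs))
print("Bypassed:", result)  # → Bypassed: [3. 3. 3. 3.]
```

This fix only helps when you control the pytree registration; with flax.struct.dataclass the unflatten rule belongs to Flax.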
Unsurprisingly, I'd recommend using Equinox instead :)
Hey, runtime type-checking seems to fail when passing a Flax dataclass to a vmapped function. I wasn't able to find any related resources. Here is a minimal reproduction with the associated error.
It raises the following error (with beartype):
Here are the versions I'm using:
I tested it: it works with chex.dataclass and equinox.Module, but in my case I have no choice but to use Flax dataclasses. I would love to find a workaround. Thanks!!