patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

bug: can't type flax.struct.dataclass with vmapped functions #177

Open Egiob opened 6 months ago

Egiob commented 6 months ago

Hey, runtime type-checking seems to fail when a Flax dataclass is passed to a vmapped function. I wasn't able to find any related resources. Here is a minimal reproduction with the associated error.

import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array

@flax.struct.dataclass
class Data:
    a: Array

def f(x: Data) -> int:
    return 1

data = Data(a=jnp.ones(1, dtype=int))

jax.vmap(f)(data)

It raises the following error (with beartype as the typechecker):

E   jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of Data.
E   The problem arose whilst typechecking argument 'a'.
E   Called with arguments: {'self': Data(...), 'a': <object object at 0x7fc7c87e8fc0>}
E   Parameter annotations: (self: Any, a: jax.Array).

Here are the versions I'm using:

flax==0.8.1
jax==0.4.21
jaxtyping==0.2.25

I tested it, and it works with chex.dataclass and equinox.Module, but in my case I have no choice but to use Flax dataclasses. I'd love to find a workaround. Thanks!!

patrick-kidger commented 6 months ago

That's odd -- I've just tried running your code (with the same versions of each library) and don't see the same issue. Can you perhaps double-check in a new environment?

Egiob commented 6 months ago

I see, my minimal reproduction was ambiguous, sorry about that. I figured out that it depends on the order of the @jaxtyped decorator. This fails:

import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped

@jaxtyped(typechecker=beartype.beartype)
@flax.struct.dataclass
class Data:
    a: Array

def f(x: Data) -> int:
    return 1

data = Data(a=jnp.ones(10, dtype=int))

jax.vmap(f)(data)

This doesn't fail:

import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped

@flax.struct.dataclass
@jaxtyped(typechecker=beartype.beartype)
class Data:
    a: Array

def f(x: Data) -> int:
    return 1

data = Data(a=jnp.ones(10, dtype=int))

jax.vmap(f)(data)

I think the error occurred for me because I used the pytest hook, which, according to the docs, adds the jaxtyped decorator on top.

Tell me if you can reproduce this :smile: (I have beartype==0.17.2 but I don't think it matters)
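Stacked decorators apply bottom-up, so the failing version is jaxtyped(...)(flax.struct.dataclass(Data)) and the working one is flax.struct.dataclass(jaxtyped(...)(Data)). The pure-Python sketch here illustrates why that order can matter. checked_init is a hypothetical stand-in for a typechecking class decorator, not jaxtyping's implementation; the rule that it only wraps __init__ when the class is already a dataclass is an assumption made for illustration.

```python
import dataclasses


def checked_init(cls):
    """Hypothetical typechecking class decorator (a stand-in, not jaxtyping's code).

    Assumption for illustration: it only wraps __init__ if `cls` is already a
    dataclass; applied to a plain class, it returns the class unchanged.
    """
    if not dataclasses.is_dataclass(cls):
        return cls
    original_init = cls.__init__

    def __init__(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Compare each field's value against its (plain-class) annotation.
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if not isinstance(value, field.type):
                raise TypeError(
                    f"{field.name!r}: expected {field.type.__name__}, "
                    f"got {type(value).__name__}"
                )

    cls.__init__ = __init__
    return cls


# Decorators apply bottom-up: Outer = checked_init(dataclass(Outer)).
@checked_init
@dataclasses.dataclass(frozen=True)
class Outer:
    a: float


# Reversed: Inner = dataclass(checked_init(Inner)); checked_init sees a plain class.
@dataclasses.dataclass(frozen=True)
@checked_init
class Inner:
    a: float


# A non-array placeholder, like the sentinel value in the error message above.
placeholder = object()

# Rebuilding the instance via the constructor (as a pytree unflatten rule might) runs the check.
try:
    Outer(a=placeholder)
    outer_accepted_placeholder = True
except TypeError:
    outer_accepted_placeholder = False

# No check was installed on Inner, so the placeholder goes through unnoticed.
inner = Inner(a=placeholder)
```

In this sketch the reversed order passes only because no check was ever installed, so a passing run is not by itself evidence that the fields were validated.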

patrick-kidger commented 6 months ago

Ah, thank you!

It looks like this is a bug in Flax itself. Here's a MWE that doesn't use jaxtyping:

import flax
import jax.tree_util as jtu

@flax.struct.dataclass
class A:
    x: int

    def __init__(self):
        pass

leaves, treedef = jtu.tree_flatten(A())
jtu.tree_unflatten(treedef, leaves)

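It looks like the reason is that Flax's tree-unflattening rule rebuilds the object by calling the type's __init__ method. This is a long-standing gotcha when using JAX: transformations may unflatten a pytree with placeholder leaves (note the 'a': <object object at ...> in the error above), so an __init__ that type-checks its arguments will reject them.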
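To make the gotcha concrete without Flax, here is a minimal sketch using jax.tree_util.register_pytree_node directly. ViaInit and BypassInit are hypothetical classes whose __init__ performs a runtime isinstance check, standing in for a typechecked dataclass. ViaInit is unflattened through __init__ (the pattern described above); BypassInit allocates with object.__new__ and sets the field directly, so unflattening never runs the check. Passing a bare object() as a leaf mimics the placeholder in the error message; that jax.vmap does something similar internally is an assumption based on that message, not a documented guarantee.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class ViaInit:
    """Hypothetical pytree whose unflatten rule calls a validating __init__."""

    def __init__(self, a):
        if not isinstance(a, jax.Array):
            raise TypeError(f"expected jax.Array, got {type(a).__name__}")
        self.a = a


jtu.register_pytree_node(
    ViaInit,
    lambda x: ((x.a,), None),                   # flatten: one child, no aux data
    lambda aux, children: ViaInit(*children),   # unflatten: goes through __init__
)


class BypassInit:
    """Same validation in __init__, but the unflatten rule skips __init__."""

    def __init__(self, a):
        if not isinstance(a, jax.Array):
            raise TypeError(f"expected jax.Array, got {type(a).__name__}")
        self.a = a


def _unflatten_bypass(aux, children):
    obj = object.__new__(BypassInit)  # allocate without running __init__
    (obj.a,) = children
    return obj


jtu.register_pytree_node(BypassInit, lambda x: ((x.a,), None), _unflatten_bypass)

# An arbitrary non-array leaf, like the placeholder in the error message.
placeholder = object()

_, treedef = jtu.tree_flatten(ViaInit(jnp.ones(3)))
try:
    jtu.tree_unflatten(treedef, [placeholder])
    via_init_ok = True
except TypeError:
    via_init_ok = False  # __init__ rejected the placeholder

_, treedef = jtu.tree_flatten(BypassInit(jnp.ones(3)))
rebuilt = jtu.tree_unflatten(treedef, [placeholder])  # no validation runs

# With real arrays, BypassInit still works under jax.vmap.
out = jax.vmap(lambda x: x.a * 2)(BypassInit(jnp.arange(3.0)))
```

As far as we know, skipping __init__ in the unflatten rule is one of the workarounds discussed in JAX's pytree documentation. The trade-off is that validation then only runs when user code calls the constructor directly.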

Unsurprisingly, I'd recommend using Equinox instead :)