patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Stateful Equinox Module: how to annotate? #253

Open EtaoinWu opened 1 month ago

EtaoinWu commented 1 month ago

I recently came up with the following code:

from typing import Self  # typing.Self requires Python 3.11+

import equinox as eqx
from beartype import beartype
from jax import numpy as jnp
from jaxtyping import Array, Float, jaxtyped

@jaxtyped(typechecker=beartype) # to typecheck __init__
@beartype
class Accumulator(eqx.Module):
    x: Float[Array, " n"]

    @jaxtyped  # bare form (no typechecker=...): emits the UserWarning discussed below
    def add(self, y: Float[Array, " n"]) -> Self:
        return self.__class__(self.x + y)

Now, when running this code, jaxtyped complains with a UserWarning saying that it prefers the @jaxtyped(typechecker=beartype) syntax. (This warning was added before beartype's __instancecheck_str__ pseudostandard was implemented.) However, in this context that syntax makes beartype raise an error, because beartype lacks the context to figure out what typing.Self refers to. Therefore, the code above is the only way to get it running.

However, this Accumulator has a problem. If you write:

@jaxtyped(typechecker=beartype)
def test_accumulator():
    x = jnp.ones(3)
    y = jnp.ones(4)
    acc1 = Accumulator(x)
    acc1 = acc1.add(x)
    acc2 = Accumulator(y)
    acc2 = acc2.add(y)
    return acc1, acc2

test_accumulator()  # raises BeartypeCallHintParamViolation inside acc2.add(y)

When calling acc2.add(y), it seems that n=3 is still in the memo from the earlier acc1.add(x) call, so beartype raises a BeartypeCallHintParamViolation.
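The suspected failure can be modelled with a toy dimension memo. This is a sketch under stated assumptions, not jaxtyping's implementation: the hypothetical `DimMemo` class only assumes that a named dimension binds to a size on first use and must match on every later check against the same memo.

```python
class DimMemo:
    """Toy model of a named-dimension memo (hypothetical; not jaxtyping's internals)."""

    def __init__(self) -> None:
        self.bindings: dict[str, int] = {}

    def check(self, name: str, size: int) -> bool:
        # First use of a name binds it; later uses must agree with that binding.
        bound = self.bindings.setdefault(name, size)
        return bound == size


# One memo shared by both calls, as suspected above.
shared = DimMemo()
print(shared.check("n", 3))  # acc1.add(x): binds n=3 -> True
print(shared.check("n", 4))  # acc2.add(y): n=3 already bound -> False

# A fresh memo per call avoids the conflict.
print(DimMemo().check("n", 3), DimMemo().check("n", 4))  # True True
```

In this model, neither call fails on its own; the error only appears because both calls check against one shared memo.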

So my question is: how does one properly type-annotate this kind of class?

patrick-kidger commented 1 month ago

So the issue here is actually Accumulator.__init__. The lone @beartype adds type-checks to the __init__ method, and those checks are what produce the n=3 binding.

Unfortunately, jaxtyping+beartype+Self just isn't really a supported combination right now.

EtaoinWu commented 1 month ago

> So the issue here is actually Accumulator.__init__.

Interestingly, it is not. If we comment out the line acc1 = acc1.add(x):

# @jaxtyped(typechecker=beartype) # With or without this line
@beartype
class Accumulator(eqx.Module):
    ...

@jaxtyped(typechecker=beartype)
def test_accumulator():
    x = jnp.ones(3)
    y = jnp.ones(4)
    acc1 = Accumulator(x)
    # acc1 = acc1.add(x)
    acc2 = Accumulator(y)
    acc2 = acc2.add(y)
    return acc1, acc2

test_accumulator()

This actually doesn't raise any error.