Open EtaoinWu opened 1 month ago
So the issue here is actually `Accumulator.__init__`. When you have the lone `@beartype` then this is adding type-checks to the `__init__` method, and these are what are producing the `n=3` binding.
Unfortunately, jaxtyping+beartype+Self just isn't really a supported combination right now.
> So the issue here is actually `Accumulator.__init__`.
Interestingly, it is not. If we comment out the line with `acc1 = acc1.add(x)`,
```python
# @jaxtyped(typechecker=beartype)  # With or without this line
@beartype
class Accumulator(eqx.Module):
    ...

@jaxtyped(typechecker=beartype)
def test_accumulator():
    x = jnp.ones(3)
    y = jnp.ones(4)
    acc1 = Accumulator(x)
    # acc1 = acc1.add(x)
    acc2 = Accumulator(y)
    acc2 = acc2.add(y)
    return acc1, acc2

test_accumulator()
```
This actually doesn't raise any error.
I recently came up with the following code:

Now, when running this code, `jaxtyped` complained in a `UserWarning` saying that it prefers the `@jaxtyped(typechecker=beartype)` syntax. (This warning was added before beartype's `__instancecheck_str__` pseudostandard was implemented.) However, in this context, that syntax leads to an error from `beartype`, because it lacks the context to figure out what `typing.Self` refers to. Therefore the code above is the only way to get it running.

However, this `Accumulator` faces an issue: If you write

In calling `acc2.add(y)`, it seems that `n=3` is still in the memo from the previous `acc1.add(x)` call, and a type check error `BeartypeCallHintParamViolation` will be raised.

So, my question is: how does one properly type-annotate this kind of class?
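
To make the suspected failure mode concrete, here is a toy model in plain Python. It is only a sketch and not jaxtyping's actual implementation: `DimMemo` and `check_calls` are hypothetical names invented for illustration, and the only assumption carried over from the description above is that a named axis such as `n` is bound on first use and compared against on later uses.

```python
# Toy model only: this is NOT how jaxtyping is implemented internally.
# DimMemo and check_calls are hypothetical names used for illustration.


class DimMemo:
    """Remembers the first size seen for each named axis."""

    def __init__(self) -> None:
        self.bindings: dict[str, int] = {}

    def check(self, axis: str, size: int) -> bool:
        # The first use binds the axis; later uses must match that binding.
        bound = self.bindings.setdefault(axis, size)
        return bound == size


def check_calls(sizes: list[int], share_memo: bool) -> list[bool]:
    """Check axis "n" against each size, sharing one memo or using a fresh one per call."""
    shared = DimMemo()
    results = []
    for size in sizes:
        memo = shared if share_memo else DimMemo()
        results.append(memo.check("n", size))
    return results


# One memo shared across both calls: the second check still sees n=3.
shared_results = check_calls([3, 4], share_memo=True)
# A fresh memo per call: each check starts with no bindings.
fresh_results = check_calls([3, 4], share_memo=False)
print(shared_results, fresh_results)
```

In this model, the second check fails only when the memo outlives the first call, which matches the symptom of `n=3` leaking from `acc1.add(x)` into `acc2.add(y)`.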