patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Incompatibility with flax.linen.tabulate #209

Closed: kvablack closed this issue 1 month ago

kvablack commented 1 month ago

It seems like flax.linen.tabulate uses its own unusual internal array representation, which causes errors like this:

jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of Observation.
The problem arose whilst typechecking parameter 'images'.
Actual value: {
  'base_0_rgb':
  _ArrayRepresentation(shape=(2, 60, 80, 3), dtype=dtype('float32')),
  'left_wrist_0_rgb':
  _ArrayRepresentation(shape=(2, 60, 80, 3), dtype=dtype('float32'))
}
Expected type: dict[str, Float[Array, 'b _h _w c']].

One solution would be to globally disable type-checking during the tabulate call. However, the only way I found to do this is typeguard.suppress_type_checks(), which was only added in typeguard version 3.

Any suggestions?

patrick-kidger commented 1 month ago

A first recourse would be to report this as a bug in Flax. I suspect they'd be amenable to a PR fixing this.

An alternative option is to use Equinox instead, of course!

How are you applying jaxtyping's typechecking? For the more general problem of skipping certain modules, the import hook only typechecks the modules you explicitly ask it to. (So you can list your own modules and leave third-party modules out.)

kvablack commented 1 month ago

> An alternative option is to use Equinox instead, of course!

If only it were that easy! Migrating the entire codebase to Equinox is unfortunately a bit out of scope 😅

Upon further inspection, it looks like this only happens when typechecking a flax.struct.dataclass.

I'm using the decorator syntax to typecheck only certain functions and dataclasses, so I'm not sure how the import hook would help. jaxtyping works great most of the time (during init/apply/etc.), and I would really like to keep typechecking my dataclasses at those times! It just breaks during tabulate. Do you know of any other way to disable typechecking for just that one call?

I also tried defining Array = jax.Array | flax.linen.summary._ArrayRepresentation and using that type annotation instead, to no avail.

patrick-kidger commented 1 month ago

Ah, sorry, I had to look up the source code for tabulate to understand what you're referring to. IIUC, it's actually initialising the dataclass for you, but with arguments of the wrong type?

For what it's worth, jaxtyping (for historical reasons) actually checks the attributes of the class after initialisation, not the arguments to __init__. So if widening the annotation didn't work for you above, might you have widened the arguments of a custom __init__? It's the types of the fields that you'd need to change (to e.g. Float[Array, "foo bar"] | flax.linen.summary._ArrayRepresentation).
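To make the fields-versus-arguments distinction concrete, here is a stdlib-only toy analog. It is not jaxtyping's code and does not use Flax: check_fields_after_init is a hypothetical decorator that, mirroring the behaviour described above, checks each dataclass field once __init__ has finished, and ArrayLike / ArrayRepresentation are hypothetical stand-ins for jax.Array and flax.linen.summary._ArrayRepresentation.

```python
import dataclasses
from typing import Union


class FieldTypeError(TypeError):
    """Raised by the toy checker when a field's value doesn't match its annotation."""


def check_fields_after_init(cls):
    """Hypothetical toy decorator, NOT jaxtyping's implementation.

    Wraps cls.__init__ so that, after __init__ has run, every dataclass
    *field* is checked against its annotation. Annotations on __init__'s own
    parameters are never consulted. Only plain classes and typing.Union of
    plain classes are supported.
    """
    original_init = cls.__init__

    def __init__(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        for field in dataclasses.fields(self):
            # Union[A, B].__args__ == (A, B); a plain class has no __args__.
            allowed = getattr(field.type, "__args__", (field.type,))
            value = getattr(self, field.name)
            if not isinstance(value, allowed):
                raise FieldTypeError(
                    f"field {field.name!r} got {type(value).__name__}"
                )

    cls.__init__ = __init__
    return cls


# Hypothetical stand-ins for jax.Array and flax.linen.summary._ArrayRepresentation.
class ArrayLike:
    pass


class ArrayRepresentation:
    pass


@check_fields_after_init
@dataclasses.dataclass
class CustomInit:
    images: ArrayLike  # field annotation: this is what the toy checks

    def __init__(self, images: Union[ArrayLike, ArrayRepresentation]):
        # Widening only the *argument* annotation does not widen the check.
        self.images = images


@check_fields_after_init
@dataclasses.dataclass
class WidenedField:
    images: Union[ArrayLike, ArrayRepresentation]  # widened *field* annotation
```

In the toy, CustomInit(ArrayRepresentation()) is still rejected even though its __init__ argument accepts the union, while WidenedField(ArrayRepresentation()) passes. With real jaxtyping the analogous change would go on the field annotation, although the last comment in this thread reports a separate problem with unions.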

Other than that, jaxtyping does offer an opt-out flag that you could set for the duration of the call: https://github.com/patrick-kidger/jaxtyping/blob/49cf97ec84d5d8b5d6571582b120b9ab2358689c/jaxtyping/_config.py#L28
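A minimal sketch of "set it for the duration of the call" as a context manager. It assumes the config object behind the link exposes an update(name, value) method and that the current value can be read back as an attribute. At the linked commit that object appears to be jaxtyping.config with a flag named jaxtyping_disable, but treat both names as assumptions and check _config.py for your installed version. config_override is a hypothetical helper, not part of jaxtyping.

```python
import contextlib
from typing import Any, Iterator


@contextlib.contextmanager
def config_override(config: Any, name: str, value: Any) -> Iterator[None]:
    """Hypothetical helper: temporarily set `name` on a config object.

    Assumes `config` has an `update(name, value)` method and exposes the
    current value as an attribute. The previous value is restored on exit,
    even if the body raises.
    """
    previous = getattr(config, name)
    config.update(name, value)
    try:
        yield
    finally:
        config.update(name, previous)


# Hypothetical usage (assumes jaxtyping.config and its "jaxtyping_disable" flag):
#   import jaxtyping
#   with config_override(jaxtyping.config, "jaxtyping_disable", True):
#       print(model.tabulate(rng, obs))
```

Because the flag is global, every jaxtyping check is skipped while the with block is active (not just the dataclass that trips over _ArrayRepresentation), and the toggle is not thread-safe. Restoring the value in finally means an exception inside tabulate does not leave checking switched off.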

kvablack commented 1 month ago

Ah thanks, the opt-out flag is the escape hatch I'm looking for!

It seems the reason jax.Array | flax.linen.summary._ArrayRepresentation didn't work, though, is a separate problem with unions. I'll open another issue for that.