danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

ENH: use jaxtyping #150

Closed: nstarman closed this issue 2 months ago

nstarman commented 5 months ago

If you're building on equinox, might as well use jaxtyping for more detailed type hints, including shape information.

danielward27 commented 5 months ago

Thanks for the suggestion and the motivation to start thinking about this. It would improve the robustness and user-friendliness of the package. One pitfall I've fallen into a couple of times:

```python
import jax.numpy as jnp
from flowjax.distributions import StandardNormal

normal = StandardNormal(shape=5)  # Annotated as a tuple of ints, but we pass an int
normal.log_prob(jnp.ones(5))  # ...and it only errors here!
```

Definitely would be nice to avoid issues like this!

One slight annoyance is that, given an object with a shape attribute, afaik there isn't a particularly natural way to use it for type checking directly in the object's methods (see https://github.com/patrick-kidger/jaxtyping/pull/140). Maybe something along the lines of:

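```python
import equinox as eqx
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped


# Class-level decorator: checks the annotated fields (here `shape`) after __init__.
@jaxtyped(typechecker=typechecker)
class SomeClass(eqx.Module):
    shape: tuple[int, ...]

    def __call__(self, arr: Array) -> Array:
        # Turn the runtime shape into a jaxtyping shape string, e.g. (2, 3) -> "2 3".
        strshape = " ".join(map(str, self.shape))

        @jaxtyped(typechecker=typechecker)
        def _shape_check(arr: Float[Array, strshape]):
            return arr

        _shape_check(arr)  # Raises if arr's shape or dtype doesn't match.
        return arr + 1
```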

For bijections, something like this could probably be added to _unwrap_check_and_cast, and for distributions it would replace _check_shapes. Whilst something like the above might be possible, I'd also like to make sure the errors remain as clear as they are now.
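To show what "clear" can mean here: a hand-written check can report the expected and received shapes in plain terms. This is a minimal sketch. `check_shape` is a hypothetical helper, not flowjax's `_check_shapes` or `_unwrap_check_and_cast`.

```python
import jax.numpy as jnp


def check_shape(arr, expected_shape: tuple[int, ...], name: str = "x"):
    """Hypothetical helper: raise a readable ValueError on a shape mismatch."""
    actual = jnp.shape(arr)
    if actual != tuple(expected_shape):
        raise ValueError(
            f"Expected {name} with shape {tuple(expected_shape)}, got {actual}."
        )
    return arr


check_shape(jnp.ones(3), (3,))  # passes and returns the array unchanged
# check_shape(jnp.ones((2, 3)), (3,), name="arr")
# -> ValueError: Expected arr with shape (3,), got (2, 3).
```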

Regardless, I think a good first step would be to leave the shape checking in bijections and distributions to the wrappers flowjax currently uses, and just use jaxtyping to check function calls and module instantiation. I'm happy to accept pull requests on this, or I can take a look at some point over the coming couple of months.

danielward27 commented 4 months ago

I've started adding it here https://github.com/danielward27/flowjax/pull/154. The tests now pass whilst running beartype. I've added some shape annotations, although I'm sure more detail could be added in places.

I'll probably merge https://github.com/danielward27/flowjax/pull/154 as it stands, as a starting point. However, one question is how far we want to enforce type checking (likely using beartype):

Let me know if you have any thoughts @nstarman or others

nstarman commented 4 months ago

@danielward27 #154 looks great!

I definitely agree that point 1 is the easiest, especially to start. Adding a few sentences to the docs plus some links would be enough for people to know how to enable runtime type checking via jaxtyping's import hook. Longer term I would recommend doing point 2. If you use an environment variable when setting up jaxtyping, users would still be able to choose their type checker by setting that variable, e.g. FLOWJAX_RUNTIME_TYPECHECKER="beartype.beartype".

danielward27 commented 2 months ago

I'll close this for now. Runtime type checking is great, but the options are not yet very mature (and Python's typing system is a bit of a mess at the best of times), so it is probably better to avoid relying on them at the moment. For now we use it for testing, and people can choose to opt in. Hopefully this is an issue to return to in the future, once things have advanced a bit.