patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

PRNGKeyArray doesn't support new (experimental) RBG PRNG impl #252

Closed csuter closed 1 month ago

csuter commented 2 months ago

After calling jax.config.update('jax_default_prng_impl', 'unsafe_rbg'), the value returned by jax.random.key(N) is (effectively) a UInt32[4] (it's actually a scalar-shaped array with dtype("<urbg>")). This causes type errors when annotating a PRNG key object with jt.PRNGKeyArray, which expects only a scalar u32 or a u32[2].

Here's a colab showing the jax behavior with unsafe_rbg: https://colab.research.google.com/drive/1FQtKBC8jCj2f8s3uABhFPuDjlI3Kz2AC?authuser=0#scrollTo=JU9QzhSxQN7T

c.f. also

I think the fix would simply be adding another union to the current PRNGKeyArray definition covering the new case. Happy to send a PR if that's welcome.
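
For reference, a minimal sketch of the two key representations under unsafe_rbg. It uses only jax itself; the comparison with the legacy jax.random.PRNGKey constructor is an addition for illustration and is not something the report mentions.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Same config switch as in the report; set before creating any keys.
jax.config.update("jax_default_prng_impl", "unsafe_rbg")

# New-style typed key: scalar-shaped array with a key dtype.
typed_key = jr.key(0)
print(typed_key.shape)  # ()
print(jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key))  # True

# The raw data behind the typed key is four uint32 words for RBG.
raw_data = jr.key_data(typed_key)
print(raw_data.shape)  # (4,)

# Legacy constructor (assumption: not mentioned in the report) returns
# the raw uint32 array directly rather than a typed key.
legacy_key = jr.PRNGKey(0)
print(legacy_key.shape)  # (4,)
```

If a raw uint32[4] array like legacy_key is what ends up annotated, that would be consistent with the error described, though the thread does not establish which constructor was used.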

csuter commented 2 months ago

Oh perhaps I misunderstood and the "new-style...dtype key" is meant to cover this case. I'll have to take a closer look at this on my end. Any insights would be welcome, though. I definitely did get a (runtime) type error from such a setup and it went away when I (in my code) typed it as a jt.PRNGKeyArray | UInt32[Array, "4"].

patrick-kidger commented 2 months ago

Hmm, this works for me:

from typing import get_args

import jax
import jax.random as jr
from jaxtyping import Array, Key, PRNGKeyArray

# Switch the default PRNG implementation before creating any keys.
jax.config.update('jax_default_prng_impl', 'unsafe_rbg')

# jr.key returns a new-style typed key (scalar shape, key dtype).
x = jr.key(0)

print(isinstance(x, get_args(PRNGKeyArray)))  # True
print(isinstance(x, Key[Array, ""]))  # True

This is using JAX 0.4.33 and jaxtyping 0.2.34.

csuter commented 1 month ago

I'm going to close this for now, to keep the tracker tidy, under the assumption that I was doing something silly. I'll update here if I learn anything helpful!