Closed csuter closed 1 month ago
Oh perhaps I misunderstood and the "new-style...dtype key jt.PRNGKeyArray | UInt32[Array, "4"].
Hmm, this works for me:
from typing import get_args
import jax
import jax.random as jr
from jaxtyping import Array, Key, PRNGKeyArray
jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
x = jr.key(0)
print(isinstance(x, get_args(PRNGKeyArray))) # True
print(isinstance(x, Key[Array, ""])) # True
This is using JAX 0.4.33 and jaxtyping 0.2.34.
I'm going to close this for now, to keep the tracker tidy, under the assumption that I was doing something silly. I'll update here if I learn anything helpful!
When jax.config.update('jax_default_prng_impl', 'unsafe_rbg') is set, the value returned by jax.random.key(N) will (effectively) be a UInt32[4] (it's actually a scalar-shaped array with dtype("<urbg>")). This causes type errors when annotating a PRNG key object with jt.PRNGKeyArray, which expects only scalar u32 or u32[2].
Here's a colab showing the JAX behavior with unsafe_rbg: https://colab.research.google.com/drive/1FQtKBC8jCj2f8s3uABhFPuDjlI3Kz2AC?authuser=0#scrollTo=JU9QzhSxQN7T
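For anyone without access to the colab, here is a minimal sketch of the behavior described above. It passes impl= per key instead of changing the global jax_default_prng_impl setting, which should be equivalent for this purpose; the helper name key_data_shape is made up for illustration.

```python
import jax
import jax.random as jr


def key_data_shape(impl: str) -> tuple:
    """Return the shape of the raw uint32 buffer behind a new-style key."""
    key = jr.key(0, impl=impl)
    # New-style keys are scalar-shaped arrays with a dedicated key dtype.
    assert key.shape == ()
    assert jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
    # key_data exposes the underlying uint32 words of the key.
    return tuple(jr.key_data(key).shape)


shapes = {
    impl: key_data_shape(impl)
    for impl in ("threefry2x32", "rbg", "unsafe_rbg")
}
print(shapes)
```

The key itself stays scalar-shaped in every case; only the underlying data differs in length (two words for threefry2x32, four for the rbg variants).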
c.f. also
I think the fix would simply be to add another member to the union in the current PRNGKeyArray definition, covering the new case. Happy to send a PR if that's welcome.