NeilGirdhar opened 7 months ago
@cgarciae
Maybe we should raise a nice error message in a custom __jax_array__ telling users to use .fork? We should think about this a little more.
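Here is a minimal sketch of that idea. It assumes JAX invokes an object's __jax_array__ hook when it tries to convert the object to an array, and that any exception raised there propagates to the caller. KeyStream and its fork method are hypothetical stand-ins, not the nnx classes:

```python
class KeyStream:
    """Hypothetical stand-in for an RNG stream object (not the nnx class)."""

    def __init__(self, seed: int):
        self.seed = seed
        self.count = 0

    def fork(self) -> "KeyStream":
        # Placeholder: return a new stream derived from this one.
        self.count += 1
        return KeyStream(self.seed + self.count)

    def __jax_array__(self):
        # Raise a targeted error instead of letting array conversion
        # fail with a generic message further down the stack.
        raise TypeError(
            f"{type(self).__name__} cannot be used as a JAX array; "
            "call .fork() and pass the result instead."
        )
```

The message names the fix directly, so a user who passes the stream where an array is expected learns what to do instead.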
Sounds great. Here's another use case that could use a better error:
import numpy as np
from flax.experimental import nnx
from jax import float0, jvp
from jax.random import normal

def f(rngs):
    k = rngs['blah']()
    return normal(k)

def zero_rngs(rngs: nnx.Rngs) -> nnx.Rngs:
    # Build a tangent Rngs whose stream keys have dtype float0 (the tangent
    # dtype JAX uses for integer inputs), keeping the original counts.
    return nnx.Rngs({key: nnx.RngStream(np.empty(stream.key.shape, dtype=float0), stream.counts)
                     for key, stream in rngs._rngs.items()})

rngs = nnx.Rngs(blah=0)
print(jvp(f, (rngs,), (zero_rngs(rngs),)))
Also, accidentally passing an Rngs for a dict[str, RngStream] parameter leads to a terrible error:
  File "/home/neil/src/cmm/cmm/basic_module/noisy_mlp.py", line 38, in noise_callback
    + (normal(rngs['inference'](), layer_value.shape, layer_value.dtype)
              ^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 148, in <lambda>
    return lambda: self._make_rng(name)
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 145, in _make_rng
    return stream.make_rng()
           ^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 148, in <lambda>
    return lambda: self._make_rng(name)
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 140, in _make_rng
    raise ValueError(f"No RNG named {name!r} or 'default' found in Rngs.")
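One way a function could report this mistake up front is to validate a dict[str, RngStream] argument before using it. This is a sketch only: RngStream, FakeRngs, and check_streams below are hypothetical stand-ins, not nnx code. FakeRngs mimics a container that supports item access without being a Mapping:

```python
from collections.abc import Mapping


class RngStream:
    """Hypothetical stand-in for a stream object; not the nnx class."""

    def __init__(self, seed: int):
        self.seed = seed


class FakeRngs:
    """Hypothetical stand-in for an Rngs-like container (not a Mapping)."""

    def __init__(self, **streams: RngStream):
        self._streams = streams

    def __getitem__(self, name: str) -> RngStream:
        return self._streams[name]


def check_streams(streams: object) -> dict[str, RngStream]:
    """Validate a dict[str, RngStream] argument before using it."""
    # Fail fast with a message naming the expected and actual types,
    # instead of letting a later lookup fail deep inside the RNG code.
    if not isinstance(streams, Mapping) or not all(
        isinstance(k, str) and isinstance(v, RngStream)
        for k, v in streams.items()
    ):
        raise TypeError(
            f"expected dict[str, RngStream], got {type(streams).__name__}"
        )
    return dict(streams)
```

Because the check uses isinstance against Mapping, an object that only implements __getitem__ is rejected immediately with a message naming its type.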
I think it may be worth considering whether you really want __getattr__ = __getitem__. The mapping interface may take a few extra characters, but I think it's the usual Python way of doing things, whereas aliasing __getattr__ to __getitem__ seems to be causing some weird errors.
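The following pure-Python sketch shows one source of those weird errors. The AliasedStreams class is hypothetical, not nnx code. Built-ins such as hasattr and three-argument getattr only catch AttributeError, so a KeyError raised from an aliased __getattr__ escapes them:

```python
class AliasedStreams:
    """Hypothetical container that aliases __getattr__ to __getitem__."""

    def __init__(self, **streams):
        self._streams = dict(streams)

    def __getitem__(self, name):
        return self._streams[name]

    # Attribute access falls back to item lookup, so a missing
    # attribute raises KeyError rather than AttributeError.
    __getattr__ = __getitem__


def probe(obj, name):
    # hasattr() only swallows AttributeError; report what actually happens.
    try:
        return hasattr(obj, name)
    except KeyError as e:
        return f"KeyError: {e}"
```

With s = AliasedStreams(params=0), probe(s, "params") returns True, but probe(s, "dropout") returns "KeyError: 'dropout'" instead of False.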
Yeah, I think we need a custom implementation of __getattr__ so we can raise AttributeErrors instead of KeyErrors when needed.
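Here is a minimal sketch of such a custom __getattr__. It assumes streams are kept in a plain dict, and the Streams class is hypothetical, not the nnx implementation. Missing names raise AttributeError, so hasattr, getattr with a default, and copy behave as usual, while item access keeps KeyError:

```python
class Streams:
    """Hypothetical container with a custom __getattr__."""

    def __init__(self, **streams):
        self._streams = dict(streams)

    def __getitem__(self, name):
        # Mapping-style access keeps KeyError semantics.
        return self._streams[name]

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        # Read _streams from __dict__ directly so an instance created
        # without __init__ (e.g. by copy or pickle) does not recurse.
        streams = self.__dict__.get("_streams", {})
        try:
            return streams[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                f"or stream {name!r}"
            ) from None
```

Reading _streams through self.__dict__ matters. copy and pickle create instances without calling __init__ and then probe for attributes such as __setstate__. Writing self._streams inside __getattr__ would re-enter __getattr__ and recurse.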
Even with a custom implementation, the return type will be invisible to type checkers and linters, whereas __getitem__ has the correct annotation. Also, if anyone inherits from Rngs, then this __getattr__ can easily cause problems, right?
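The following sketch, using hypothetical classes, shows one concrete way this can bite subclasses. Python falls back to __getattr__ whenever normal lookup raises AttributeError, including when a property getter raises it internally. With __getattr__ aliased to __getitem__, a bug inside a subclass property then surfaces as a KeyError naming the property, and the real AttributeError is lost:

```python
class AliasedStreams:
    """Hypothetical base class that aliases __getattr__ to __getitem__."""

    def __init__(self, **streams):
        self._streams = dict(streams)

    def __getitem__(self, name):
        return self._streams[name]

    __getattr__ = __getitem__


class NoisyStreams(AliasedStreams):
    """Hypothetical subclass with a buggy property."""

    @property
    def scale(self):
        # Bug: str has no method 'upperr'. The AttributeError is
        # discarded and Python retries via __getattr__("scale").
        return "noise".upperr()


def describe_failure():
    # Report the exception type and args the caller actually sees.
    try:
        NoisyStreams(params=0).scale
    except Exception as e:
        return type(e).__name__, e.args
    return None
```

describe_failure() returns ("KeyError", ("scale",)), which points at a missing stream rather than the typo inside the property.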