google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Bad interaction between nnx.Rngs and custom derivatives #3528

Open NeilGirdhar opened 7 months ago

NeilGirdhar commented 7 months ago
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.11/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 177, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.11/lib/python3.11/site-packages/jax/_src/custom_derivatives.py", line 613, in __call__
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.11/lib/python3.11/site-packages/jax/_src/custom_derivatives.py", line 613, in <listcomp>
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
                                     ^^^^^^^^^^^^^^^^
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.11/lib/python3.11/site-packages/jax/_src/core.py", line 1431, in get_aval
    return concrete_aval(x)
           ^^^^^^^^^^^^^^^^
  File "/home/neil/.cache/pypoetry/virtualenvs/cmm-tspD8tmv-py3.11/lib/python3.11/site-packages/jax/_src/core.py", line 1422, in concrete_aval
    return concrete_aval(x.__jax_array__())
                         ^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 148, in <lambda>
    return lambda: self._make_rng(name)
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 140, in _make_rng
    raise ValueError(f"No RNG named {name!r} or 'default' found in Rngs.")
ValueError: No RNG named '__jax_array__' or 'default' found in Rngs.
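The last frames show the mechanism. JAX's concrete_aval calls x.__jax_array__(), and because Rngs routes unknown attribute names to a stream lookup (the <lambda> at rnglib.py line 148), that name resolves to an RNG thunk that only fails once it is called. Below is a minimal pure-Python sketch of that pattern. ToyRngs is a hypothetical stand-in modeled on the frames above (it hands out plain integer seeds rather than JAX keys) and is not Flax's actual class.

```python
class ToyRngs:
    """Hypothetical stand-in that mimics the lookup pattern in the traceback.

    Seeds are plain ints here; the real class manages JAX PRNG keys.
    """

    def __init__(self, **seeds: int) -> None:
        self._seeds = dict(seeds)

    def _make_rng(self, name: str) -> int:
        # Fall back to a 'default' stream, mirroring the error message above.
        if name in self._seeds:
            return self._seeds[name]
        if "default" in self._seeds:
            return self._seeds["default"]
        raise ValueError(f"No RNG named {name!r} or 'default' found in Rngs.")

    def __getitem__(self, name: str):
        # Return a thunk; a bad name is only detected when the thunk is called.
        return lambda: self._make_rng(name)

    # Attribute access is aliased to item access, so any missing attribute,
    # including protocol names like __jax_array__, resolves to a thunk.
    __getattr__ = __getitem__


rngs = ToyRngs(blah=0)
seed = rngs.blah()                                 # normal use: 0
looks_array_like = hasattr(rngs, "__jax_array__")  # True, by accident

# Roughly what the traceback shows JAX doing next: calling the attribute.
try:
    rngs.__jax_array__()
    error_message = ""
except ValueError as err:
    error_message = str(err)

print(seed, looks_array_like)
print(error_message)
```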
NeilGirdhar commented 7 months ago

@cgarciae

cgarciae commented 7 months ago

Maybe we should raise a nice error message in a custom __jax_array__ that tells users to use .fork? We should think about this a little more.
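One way to read this suggestion, sketched under assumptions: define __jax_array__ explicitly on the class so ordinary attribute lookup finds it before the __getattr__ fallback ever runs, and have it raise an error that points users to .fork. The ToyRngs class, the choice of TypeError, and the message wording are all hypothetical; only the idea of pointing at .fork comes from the comment.

```python
class ToyRngs:
    """Hypothetical Rngs-like class; seeds are plain ints for illustration."""

    def __init__(self, **seeds: int) -> None:
        self._seeds = dict(seeds)

    def _make_rng(self, name: str) -> int:
        if name in self._seeds:
            return self._seeds[name]
        raise ValueError(f"No RNG named {name!r} found in Rngs.")

    def __getitem__(self, name: str):
        return lambda: self._make_rng(name)

    __getattr__ = __getitem__

    def __jax_array__(self):
        # Found by ordinary class lookup, so __getattr__ never sees this name.
        # The wording below is an illustrative guess, not Flax's message.
        raise TypeError(
            "An Rngs object cannot be converted to a JAX array; use .fork instead."
        )


rngs = ToyRngs(blah=0)
try:
    rngs.__jax_array__()
    message = ""
except TypeError as err:
    message = str(err)
print(message)
```

Whether TypeError is the right exception is an open design choice; the point is that the failure happens at the boundary, with a message that names the actual problem.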

NeilGirdhar commented 7 months ago

Sounds great. Here's another use case that could use a better error:

import numpy as np
from flax.experimental import nnx
from jax import float0, jvp
from jax.random import normal

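def f(rngs):
    k = rngs['blah']()  # draw a key from the 'blah' stream
    return normal(k)

def zero_rngs(rngs: nnx.Rngs) -> nnx.Rngs:
    # Build a tangent Rngs: each stream's integer key gets a float0 tangent
    # of the same shape, since jvp uses float0 tangents for integer primals.
    return nnx.Rngs({key: nnx.RngStream(np.empty(stream.key.shape, dtype=float0), stream.counts)
                     for key, stream in rngs._rngs.items()})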

rngs = nnx.Rngs(blah=0)
print(jvp(f, (rngs,), (zero_rngs(rngs),)))
NeilGirdhar commented 7 months ago

Also, accidentally passing an Rngs object where a dict[str, RngStream] parameter is expected leads to a terrible error:

  File "/home/neil/src/cmm/cmm/basic_module/noisy_mlp.py", line 38, in noise_callback
    + (normal(rngs['inference'](), layer_value.shape, layer_value.dtype)
              ^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 148, in <lambda>
    return lambda: self._make_rng(name)
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 145, in _make_rng
    return stream.make_rng()
           ^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 148, in <lambda>
    return lambda: self._make_rng(name)
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/flax/flax/experimental/nnx/nnx/rnglib.py", line 140, in _make_rng
    raise ValueError(f"No RNG named {name!r} or 'default' found in Rngs.")

I think it may be worth reconsidering whether you really want to have __getattr__ = __getitem__. The mapping interface may cost a few extra characters, but it's the usual Python way of doing things, and this __getattr__ magic seems to be causing some weird errors.
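A sketch of the mapping-only alternative, assuming a hypothetical MappingRngs that subclasses collections.abc.Mapping and defines no __getattr__. Unknown names then fail with an ordinary KeyError, protocol probes such as hasattr(obj, "__jax_array__") return False, and __getitem__ carries a return annotation that type checkers can see.

```python
from collections.abc import Callable, Iterator, Mapping


class MappingRngs(Mapping[str, Callable[[], int]]):
    """Hypothetical Rngs-like container exposing only the Mapping interface.

    Seeds are plain ints here; a real implementation would hand out PRNG keys.
    """

    def __init__(self, **seeds: int) -> None:
        self._seeds = dict(seeds)

    def __getitem__(self, name: str) -> Callable[[], int]:
        if name not in self._seeds:
            raise KeyError(f"No RNG named {name!r} in Rngs.")
        return lambda: self._seeds[name]

    def __iter__(self) -> Iterator[str]:
        return iter(self._seeds)

    def __len__(self) -> int:
        return len(self._seeds)


rngs = MappingRngs(blah=0)
seed = rngs["blah"]()                   # explicit item access: 0
probe = hasattr(rngs, "__jax_array__")  # False: no attribute fallback
fallback = rngs.get("missing")          # Mapping.get returns None on KeyError
print(seed, probe, fallback, sorted(rngs))
```

Mapping also supplies __contains__, keys, items, and get for free, so the only cost to callers is writing rngs["blah"] instead of rngs.blah.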

cgarciae commented 7 months ago

Yeah, I think we need a custom implementation of __getattr__ so we can raise an AttributeError instead of a KeyError when needed.
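A minimal sketch of such a custom __getattr__, assuming a hypothetical ToyRngs: item lookup raises KeyError, and __getattr__ translates it into AttributeError so that hasattr and getattr(obj, name, default) behave as Python expects. Refusing names that start with an underscore is an extra assumption of this sketch; it keeps protocol probes like __jax_array__, and half-initialized objects, away from the stream lookup.

```python
from collections.abc import Callable


class ToyRngs:
    """Hypothetical Rngs-like class; seeds are plain ints for illustration."""

    def __init__(self, **seeds: int) -> None:
        self._seeds = dict(seeds)

    def __getitem__(self, name: str) -> Callable[[], int]:
        if name not in self._seeds:
            raise KeyError(f"No RNG named {name!r} in Rngs.")
        return lambda: self._seeds[name]

    def __getattr__(self, name: str) -> Callable[[], int]:
        # Only reached after normal attribute lookup has failed.
        if name.startswith("_"):
            # Private and dunder names (e.g. __jax_array__) are never streams.
            # This also avoids infinite recursion if self._seeds is not set yet.
            raise AttributeError(name)
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(
                f"{type(self).__name__!r} object has no stream {name!r}"
            ) from err


rngs = ToyRngs(blah=0)
seed = rngs.blah()                        # 0
probe = hasattr(rngs, "__jax_array__")    # False
missing = getattr(rngs, "missing", None)  # None instead of an exception
print(seed, probe, missing)
```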

NeilGirdhar commented 7 months ago

Even with a custom implementation, the return type will be invisible to type checkers and linters whereas __getitem__ has the correct annotation. Also, if anyone inherits from Rngs, then this __getattr__ can easily cause problems, right?
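The inheritance concern can be illustrated with a hypothetical subclass of an Rngs-like class that aliases __getattr__ to __getitem__: a typo in a subclass attribute no longer raises AttributeError but quietly returns an RNG thunk, so the bug surfaces later and far from its cause. Both classes below are illustrative stand-ins, not Flax code.

```python
class ToyRngs:
    """Hypothetical base class with __getattr__ aliased to __getitem__."""

    def __init__(self, **seeds: int) -> None:
        self._seeds = dict(seeds)

    def _make_rng(self, name: str) -> int:
        if name in self._seeds:
            return self._seeds[name]
        raise ValueError(f"No RNG named {name!r} found in Rngs.")

    def __getitem__(self, name: str):
        return lambda: self._make_rng(name)

    __getattr__ = __getitem__


class LabeledRngs(ToyRngs):
    """Hypothetical subclass that adds a label."""

    def __init__(self, label: str, **seeds: int) -> None:
        super().__init__(**seeds)
        self.label = label

    def describe(self) -> str:
        # Typo: 'lable' instead of 'label'. On an ordinary class this line
        # raises AttributeError; here __getattr__ hands back a thunk instead.
        return f"LabeledRngs({self.lable})"


rngs = LabeledRngs("train", dropout=0)
text = rngs.describe()  # no error; the text embeds a lambda's repr
print(text)
```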