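Using `jax_getattr` blows up if both of the following conditions hold:
- It is invoked inside a function that is `jax.lax.scan`'d over (and probably inside other loops too; the reproducer below uses `lax.fori_loop`, which is implemented via `scan` when its bounds are static).
- The retrieved attribute is a PyTree whose number of leaves is not equal to 1.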
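To make the second condition concrete, the leaf count is what `jax.tree_util.tree_leaves` returns for the attribute value. This small sketch uses the dict values from the reproducer below, plus an empty dict as an extra case that is an assumption here. It shows that only the `--bug` value has more than one leaf. By the stated condition, a zero-leaf value would presumably also qualify, though that case is untested.

```python
import jax

# Leaf counts for a few candidate attribute values (illustrative only).
examples = {
    "dict(x=0)": dict(x=0),            # initial C().vals in the reproducer
    "dict(x=1)": dict(x=1),            # the non-bug case
    "dict(x=1, z=2)": dict(x=1, z=2),  # the --bug case
    "dict()": dict(),                  # assumed extra case: an empty container
}

leaf_counts = {
    name: len(jax.tree_util.tree_leaves(value)) for name, value in examples.items()
}
for name, n in leaf_counts.items():
    print(f"{name}: {n} leaves")
```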
Reproducer (pass `--bug` to trigger the problem). With `--bug`, both `--jit` and `--no-jit` fail, but slightly differently:
```python
import jax
from jax import lax
from jax.experimental.attrs import jax_getattr
import argparse


class C:
    def __init__(self):
        self.vals = dict(x=0)


state = C()


def f(y):
    # Read the tracked attribute from inside the loop body.
    v = jax_getattr(state, "vals")
    return v["x"] + y


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--jit", default=False, action="store_true")
    parser.add_argument("--no-jit", action="store_false", dest="jit")
    parser.add_argument("--bug", default=False, action="store_true")
    args = parser.parse_args()

    def f_iter(i, v):
        return f(v)

    def do_loop():
        return lax.fori_loop(0, 5, f_iter, 0)

    if args.jit:
        do_loop = jax.jit(do_loop)

    if args.bug:
        state.vals = dict(x=1, z=2)  # two leaves: triggers the failure
    else:
        state.vals = dict(x=1)  # one leaf: no failure

    print(f"running with {state.vals=}")
    print(f"{do_loop()=}")


if __name__ == "__main__":
    main()
```
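For contrast, here is a possible workaround, sketched under the assumption that the attribute only needs to be read (not written) inside the loop: pass the multi-leaf dict through the `fori_loop` carry instead of calling `jax_getattr` in the body. It uses only `lax.fori_loop` and `jax.jit`, not the experimental attrs API, so it sidesteps the bug rather than explaining it.

```python
import jax
from jax import lax

# Workaround sketch (not an attrs API): thread the multi-leaf dict through the
# loop carry so no jax_getattr call happens inside the traced loop body.
vals = dict(x=1, z=2)  # same value as the --bug case


def f_iter(i, carry):
    vals, y = carry
    return vals, vals["x"] + y  # vals is passed through unchanged


def do_loop(vals):
    _, y = lax.fori_loop(0, 5, f_iter, (vals, 0))
    return y


result = do_loop(vals)
result_jit = jax.jit(do_loop)(vals)
print(result, result_jit)
```

Both the plain and jitted calls should return 5, i.e. `x=1` added to 0 five times.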
The proximate issue seems to be a confusion in `loops.py` about whether the leaves of tracked PyTrees are flattened or not, but I haven't worked through the details.
For example, this code expects flattening, but perhaps this caller isn't aware of that?
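A toy illustration of the suspected mismatch follows. It uses hypothetical bookkeeping, not the actual `loops.py` code. If a loop tracks one entry per attribute but stores each attribute's flattened leaves, the two lists only line up when every attribute has exactly one leaf. Keeping each attribute's `treedef` and consuming `treedef.num_leaves` leaves per attribute reassembles the values correctly.

```python
import jax

# Hypothetical bookkeeping for illustration only; this is NOT JAX's loops.py.


class C:
    def __init__(self):
        self.vals = dict(x=1, z=2)  # the --bug value: two leaves


state = C()
tracked = [(state, "vals")]  # one entry per tracked attribute

# Flatten every tracked attribute, remembering each one's treedef.
flat, treedefs = [], []
for obj, name in tracked:
    leaves, treedef = jax.tree_util.tree_flatten(getattr(obj, name))
    flat.extend(leaves)
    treedefs.append(treedef)

# Mismatch: pairing attributes with flat leaves one-to-one silently drops "z"
# and hands back a bare leaf instead of the dict.
naive = {name: leaf for (_, name), leaf in zip(tracked, flat)}

# Consistent reassembly: consume treedef.num_leaves leaves per attribute.
rebuilt, pos = {}, 0
for (_, name), treedef in zip(tracked, treedefs):
    n = treedef.num_leaves
    rebuilt[name] = jax.tree_util.tree_unflatten(treedef, flat[pos:pos + n])
    pos += n

print(len(tracked), len(flat))  # 1 2
print(naive)    # {'vals': 1}
print(rebuilt)  # {'vals': {'x': 1, 'z': 2}}
```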
System info (python version, jaxlib version, accelerator, etc.)
I can't easily test a newer nightly right now, but the relevant code looks unchanged from a quick inspection.