jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Issue with `jax_getattr` inside `jax.scan` when the PyTree has multiple leaves #23782

Open nelhage opened 1 month ago

nelhage commented 1 month ago

Description

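Using `jax_getattr` blows up when both of the following conditions hold:

- it is called inside a scan (the reproducer uses `lax.fori_loop`, which is implemented with `lax.scan` when its bounds are static), and
- the attribute being read is a PyTree with more than one leaf.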

Reproducer (pass `--bug` to trigger the problem). Both `--jit` and `--no-jit` fail, but slightly differently.

```python
import jax
from jax import lax
from jax.experimental.attrs import jax_getattr

import argparse

class C:
    def __init__(self):
        self.vals = dict(x=0)

state = C()

def f(y):
    # Read `state.vals` through the experimental attrs API.
    v = jax_getattr(state, "vals")
    return v["x"] + y

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--jit", default=False, action="store_true")
    parser.add_argument("--no-jit", action="store_false", dest="jit")

    parser.add_argument("--bug", default=False, action="store_true")

    args = parser.parse_args()

    def f_iter(i, v):
        return f(v)

    def do_loop():
        # Static bounds, so fori_loop is implemented with lax.scan.
        return lax.fori_loop(0, 5, f_iter, 0)

    if args.jit:
        do_loop = jax.jit(do_loop)

    if args.bug:
        # Two leaves (x and z): fails.
        state.vals = dict(x=1, z=2)
    else:
        # A single leaf: works.
        state.vals = dict(x=1)

    print(f"running with {state.vals=}")
    print(f"{do_loop()=}")

if __name__ == "__main__":
    main()
```
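The title mentions `jax.scan`, while the reproducer uses `lax.fori_loop`. JAX's documentation says `fori_loop` is implemented with `scan` when its trip count is static and with `while_loop` otherwise. A quick way to check which primitive a loop ends up as is to inspect its jaxpr. This is a minimal sketch using a hypothetical `body` function rather than the reproducer's `f_iter`, so it does not touch `jax_getattr` at all:

```python
import jax
import jax.numpy as jnp
from jax import lax

def body(i, carry):
    # Hypothetical loop body: accumulate the loop index.
    return carry + i

def static_bounds(x):
    # Python-int bounds: the trip count is known at trace time.
    return lax.fori_loop(0, 5, body, x)

def dynamic_bounds(x, n):
    # Traced upper bound: the trip count is only known at run time.
    return lax.fori_loop(0, n, body, x)

static_jaxpr = str(jax.make_jaxpr(static_bounds)(jnp.int32(0)))
dynamic_jaxpr = str(jax.make_jaxpr(dynamic_bounds)(jnp.int32(0), jnp.int32(5)))

print("scan" in static_jaxpr)    # static bounds lower to a scan primitive
print("while" in dynamic_jaxpr)  # traced bounds lower to a while primitive
```

Since the reproducer's `do_loop` uses the literal bounds `0` and `5`, it takes the `scan` path, which is consistent with the issue title.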

The proximate cause seems to be confusion in `loops.py` about whether the leaves of tracked PyTrees are flattened, but I haven't worked through the details. For example, this code expects flattened leaves, but perhaps this caller isn't aware of that?
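To illustrate why a flattened/unflattened mismatch would only show up with multi-leaf PyTrees, here is a sketch with hypothetical helpers (`flatten_tracked`, `unflatten_tracked`, `unflatten_one_leaf_each`). They are not JAX's `loops.py` code, and this is an analogy for the suspected failure mode, not a diagnosis. They assume tracked attribute values are carried through a loop as one flat list of leaves:

```python
import jax

# Hypothetical helpers, not JAX internals: carry several tracked values
# through a loop as one flat list of leaves plus their treedefs.

def flatten_tracked(values):
    """Flatten each tracked value and concatenate all of their leaves."""
    leaves, treedefs = [], []
    for value in values:
        value_leaves, treedef = jax.tree_util.tree_flatten(value)
        leaves.extend(value_leaves)
        treedefs.append(treedef)
    return leaves, treedefs

def unflatten_tracked(leaves, treedefs):
    """Consistent inverse: each value consumes treedef.num_leaves leaves."""
    values, start = [], 0
    for treedef in treedefs:
        stop = start + treedef.num_leaves
        values.append(jax.tree_util.tree_unflatten(treedef, leaves[start:stop]))
        start = stop
    return values

def unflatten_one_leaf_each(leaves, treedefs):
    """Mismatched inverse: assumes every tracked value is a single leaf."""
    return [jax.tree_util.tree_unflatten(treedef, [leaf])
            for treedef, leaf in zip(treedefs, leaves)]

def one_leaf_assumption_holds(vals):
    leaves, treedefs = flatten_tracked([vals])
    try:
        return unflatten_one_leaf_each(leaves, treedefs) == [vals]
    except ValueError:  # too few leaves for a multi-leaf treedef
        return False

for vals in (dict(x=1), dict(x=1, z=2)):
    leaves, treedefs = flatten_tracked([vals])
    assert unflatten_tracked(leaves, treedefs) == [vals]
    print(vals, "one-leaf assumption holds:", one_leaf_assumption_holds(vals))
```

With `dict(x=1)` the two inverses agree, so a caller that ignores flattening would go unnoticed. With `dict(x=1, z=2)` the one-leaf version receives too few leaves for the treedef, mirroring how the reproducer only fails under `--bug`.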

System info (python version, jaxlib version, accelerator, etc.)

```
❯ python -c 'import jax; jax.print_environment_info()'
jax:    0.4.31.dev20240722
jaxlib: 0.4.31.dev20240722
numpy:  1.24.4
python: 3.11.6 | packaged by conda-forge | (main, Oct  3 2023, 10:37:07) [Clang 15.0.7 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Nelson-Elhage-MacBook', release='23.6.0', version='Darwin Kernel Version 23.6.0: Wed Jul 31 20:49:46 PDT 2024; root:xnu-10063.141.1.700.5~1/RELEASE_ARM64_T8103', machine='arm64')
```

I can't easily test a newer nightly right now, but the relevant code looks unchanged from a quick inspection.

nelhage commented 1 month ago

Attaching output and stack traces from my machine for easier skimming and in case it doesn't reproduce somehow.