google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Docstring for State is wrong about leaf types (or else the FSDP example is wrong) #4342

Closed by gabbard 3 weeks ago

gabbard commented 4 weeks ago

Problem you have encountered:

In the example for how to do Fully Sharded Data Parallelism (FSDP), we do:

    state = nnx.state(optimizer)
    ...
    def get_named_shardings(path: tuple, value: nnx.VariableState):
        if path[0] == 'params':
            ret = value.replace(NamedSharding(mesh, P(*value.sharding)))
            return ret
        elif path[0] == 'momentum':
            # currently the same as above but in general it could be different
            return value.replace(NamedSharding(mesh, P(*value.sharding)))
        else:
            raise ValueError(f'Unknown path: {path}')

    named_shardings = state.map(get_named_shardings)
    sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
    nnx.update(optimizer, sharded_state)

The code implies that state has a key type of str and a value type of VariableState (which the debugger confirms). But the docstring of State says:

    A pytree-like structure that contains a ``Mapping`` from strings or
    integers to leaves. A valid leaf type is either :class:`Variable`,
    ``jax.Array``, ``numpy.ndarray`` or nested ``State``'s....

So having a VariableState as a leaf value seems at odds with the docstring.
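
To make the mismatch concrete, here is a minimal pure-Python sketch of a path-aware map over a nested mapping. It is not Flax's implementation: `Leaf`, `map_with_path`, and `describe_sharding` are hypothetical names invented for illustration, and a plain tuple stands in for `NamedSharding`. It only shows why a callback like `get_named_shardings` needs wrapper-style leaves that carry `.sharding` and `.replace`, rather than bare arrays.

```python
import dataclasses
from typing import Any, Callable, Mapping


@dataclasses.dataclass(frozen=True)
class Leaf:
    """Hypothetical stand-in for a VariableState-like leaf (not the Flax class)."""
    value: Any
    sharding: tuple = ()

    def replace(self, value: Any) -> "Leaf":
        # Return a copy with a new payload, keeping the sharding metadata.
        return dataclasses.replace(self, value=value)


def map_with_path(fn: Callable[[tuple, Any], Any], tree: Mapping, prefix: tuple = ()) -> dict:
    """Apply fn(path, leaf) to every non-mapping leaf of a nested mapping."""
    out = {}
    for key, child in tree.items():
        path = prefix + (key,)
        if isinstance(child, Mapping):
            out[key] = map_with_path(fn, child, path)
        else:
            out[key] = fn(path, child)
    return out


def describe_sharding(path: tuple, leaf: Leaf) -> Leaf:
    # Mirrors the dispatch on path[0] in the quoted example; the tuple
    # ('named', ...) is a placeholder for a real sharding object.
    if path[0] in ('params', 'momentum'):
        return leaf.replace(('named', *leaf.sharding))
    raise ValueError(f'Unknown path: {path}')


state = {
    'params': {'w': Leaf(value=[1.0, 2.0], sharding=(None, 'model'))},
    'momentum': {'w': Leaf(value=[0.0, 0.0], sharding=(None, 'model'))},
}
shardings = map_with_path(describe_sharding, state)
```

In this sketch, a leaf that was a plain list instead of a `Leaf` would fail inside the callback with an `AttributeError`, since it has neither `.sharding` nor `.replace`. That is the same assumption about leaf types that the FSDP example makes implicitly.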

To avoid confusion, either the docstring on State should be updated, or the FSDP example should be updated.

Also, a small nit: the sharded_state defined on line 107 of the FSDP example is unused. This isn't a big deal in itself, but it creates a doubt for the reader about the correctness of the example.

cgarciae commented 3 weeks ago

Hey @gabbard, thanks for reporting this! The docstrings are outdated as we no longer treat jax.Array and np.ndarray as State leaves.