In the example for how to do Fully Sharded Data Parallelism (FSDP), we do:

```python
import jax
from flax import nnx
from jax.sharding import NamedSharding, PartitionSpec as P

state = nnx.state(optimizer)
...
def get_named_shardings(path: tuple, value: nnx.VariableState):
  if path[0] == 'params':
    ret = value.replace(NamedSharding(mesh, P(*value.sharding)))
    return ret
  elif path[0] == 'momentum':
    # currently the same as above but in general it could be different
    return value.replace(NamedSharding(mesh, P(*value.sharding)))
  else:
    raise ValueError(f'Unknown path: {path}')

named_shardings = state.map(get_named_shardings)
sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
nnx.update(optimizer, sharded_state)
```
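For readers less familiar with this pattern, here is a minimal sketch of the same mechanism using plain JAX pytrees instead of an NNX `State`. The `state` and `logical` dicts are hypothetical stand-ins. The `('data', None)` tuples play the role of each leaf's `value.sharding` metadata, and the mesh simply spans whatever devices are available.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-in for the optimizer state: plain arrays grouped under
# the same top-level keys as the FSDP example ('params', 'momentum').
state = {
  'params': {'w': jnp.ones((8, 2))},
  'momentum': {'w': jnp.zeros((8, 2))},
}

# Hypothetical per-leaf logical sharding tuples, playing the role of the
# `value.sharding` metadata that get_named_shardings reads.
logical = {
  'params': {'w': ('data', None)},
  'momentum': {'w': ('data', None)},
}

# A 1-D mesh whose single axis, 'data', spans every available device.
mesh = Mesh(np.array(jax.devices()), ('data',))

# Turn every logical tuple into a NamedSharding. is_leaf stops tree_map from
# descending into the tuples (and their None entries) themselves.
named_shardings = jax.tree_util.tree_map(
  lambda spec: NamedSharding(mesh, P(*spec)),
  logical,
  is_leaf=lambda x: isinstance(x, tuple),
)

@jax.jit
def constrain(s):
  # Ask the compiler to lay each leaf out according to its NamedSharding.
  return jax.lax.with_sharding_constraint(s, named_shardings)

sharded_state = constrain(state)
```

The NNX version differs in that `state.map` hands the callback a `(path, VariableState)` pair, so the sharding metadata travels with each leaf instead of living in a separate tree.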
The code implies that `state` has keys of type `str` and values of type `VariableState` (which the debugger confirms). But the docstring of `State` says:
```
A pytree-like structure that contains a ``Mapping`` from strings or
integers to leaves. A valid leaf type is either :class:`Variable`,
``jax.Array``, ``numpy.ndarray`` or nested ``State``'s....
```
So having a `VariableState` as a leaf value seems at odds with the docstring.

To avoid confusion, either the docstring on `State` should be updated, or the FSDP example should be updated.
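To make the mismatch concrete, here is a dependency-free sketch of the kind of data the `get_named_shardings` callback actually receives: string-keyed paths, and leaves that are wrapper objects carrying both a value and metadata rather than bare arrays. `FakeVariableState`, `map_with_path` and `get_spec` are hypothetical stand-ins written for illustration. They are not Flax's `VariableState` or `State.map`.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class FakeVariableState:
  """Hypothetical stand-in for nnx.VariableState: a value plus metadata."""
  value: Any
  sharding: Optional[tuple] = None


def map_with_path(tree: Any, fn: Callable, path: tuple = ()) -> Any:
  """Apply fn(path, leaf) to every leaf of a nested dict, keeping structure."""
  if isinstance(tree, dict):
    return {k: map_with_path(v, fn, path + (k,)) for k, v in tree.items()}
  return fn(path, tree)


state = {
  'params': {'w': FakeVariableState([1.0, 2.0], ('data', None))},
  'momentum': {'w': FakeVariableState([0.0, 0.0], ('data', None))},
}


def get_spec(path: tuple, leaf: FakeVariableState) -> FakeVariableState:
  if path[0] not in ('params', 'momentum'):
    raise ValueError(f'Unknown path: {path}')
  # Like value.replace(...) in the example: keep the wrapper, but swap its
  # value for something derived from the wrapper's own sharding metadata.
  return replace(leaf, value=leaf.sharding)


specs = map_with_path(state, get_spec)
```

Because the callback reads `leaf.sharding`, it only works when the leaves are such wrappers. A bare `numpy.ndarray` has no `.sharding` attribute, which is why the leaf types listed in the docstring matter here.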
Also, a small nit: the `sharded_state` defined on line 107 of the FSDP example is unused. This isn't a big deal in itself, but it makes the reader doubt whether the example is correct.