Open mmorinag127 opened 3 months ago
nnx.jit's API is changing very soon (#3963) according to #4107. Currently you could use nnx.iter_graph to traverse the entire model and jax.device_put to shard the arrays inside Variables.
for path, parent in nnx.iter_graph(model):
    if isinstance(parent, nnx.Module):
        for name, value in vars(parent).items():
            if isinstance(value, nnx.Variable):
                sharding = ...
                value.value = jax.device_put(value.value, sharding)
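The jax.device_put step in the loop above can be sketched on its own in plain JAX, without any NNX objects. This is only an illustration under stated assumptions: the one-axis mesh named "data", the toy params dict, and the pick_sharding rule are all hypothetical and stand in for the elided sharding = ... line.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all available devices, with a single axis named "data".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = devices.size

# Hypothetical stand-in for the arrays stored inside the model's Variables.
# The leading dimension is a multiple of the device count so it can be split evenly.
params = {"kernel": jnp.ones((2 * n, 4)), "bias": jnp.zeros((4,))}

def pick_sharding(x):
    # Hypothetical rule: split rank-2 arrays along their first axis, replicate the rest.
    spec = P("data", None) if x.ndim == 2 else P()
    return NamedSharding(mesh, spec)

# Build one sharding per leaf, then place the whole pytree in a single call.
shardings = jax.tree_util.tree_map(pick_sharding, params)
sharded = jax.device_put(params, shardings)
```

In the NNX loop, the same NamedSharding objects would be what gets assigned to sharding before calling jax.device_put on each value.value.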
Hello,
I wrote the code below for this purpose, but it fails with an error like ValueError: Mismatch custom node data: NodeDef(...
Could you please help me?
def _create_state():
    model = Model(...)
    tx = optax.adam(...)
    wrt = nnx.All(nnx.Param, nnx.Everything())
    optimizer = nnx.Optimizer(model, tx, wrt=wrt)
    graphdef, state = nnx.split((model, optimizer))
    return graphdef, state

with mesh:
    abst = nnx.eval_shape(_create_state)
    shardings = nnx.get_named_sharding(abst, mesh)
    graphdef, state = nnx.jit(_create_state, out_shardings=shardings)()
Could you point me to an example of initializing and updating a model using nnx.jit with SPMD? Is there any relevant example?
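For comparison, here is a minimal sketch of the same initialize-then-update pattern in plain JAX, using jax.jit with out_shardings. It does not use nnx.jit or any NNX API and does not address the NodeDef error above. The toy linear model, the "data" mesh axis, the sharding rule, and the fixed 0.1 step size are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all available devices, with a single axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = mesh.devices.size

def init_params(key):
    # Hypothetical toy model: a single linear layer.
    return {"w": jax.random.normal(key, (2 * n_dev, 4)), "b": jnp.zeros((4,))}

key = jax.random.PRNGKey(0)

# 1. Trace the initializer abstractly to learn shapes without allocating arrays.
abstract = jax.eval_shape(init_params, key)

# 2. Choose a sharding per leaf (hypothetical rule: split matrices on axis 0).
shardings = jax.tree_util.tree_map(
    lambda a: NamedSharding(mesh, P("data", None) if len(a.shape) == 2 else P()),
    abstract,
)

# 3. Run the initializer under jit so each array is created already sharded.
params = jax.jit(init_params, out_shardings=shardings)(key)

def loss_fn(p, x, y):
    pred = x @ p["w"] + p["b"]
    return jnp.mean((pred - y) ** 2)

def sgd_step(p, x, y):
    # Plain gradient descent with a fixed step size of 0.1.
    grads = jax.grad(loss_fn)(p, x, y)
    return jax.tree_util.tree_map(lambda a, g: a - 0.1 * g, p, grads)

# 4. Keep the updated parameters in the same layout as the inputs.
train_step = jax.jit(sgd_step, out_shardings=shardings)

x = jax.random.normal(jax.random.PRNGKey(1), (8, 2 * n_dev))
y = jnp.ones((8, 4))
new_params = train_step(params, x, y)
```

The design point is that the shardings are computed once from the abstract shapes and then reused for both the initializer and the update step, so the parameters never have to be materialized unsharded on a single device.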