google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

SPMD for initializing model using nnx.jit #4129

Open mmorinag127 opened 3 months ago

mmorinag127 commented 3 months ago

I would like to ask for an example of initializing and updating a model using nnx.jit with SPMD. Is there a relevant example?

cgarciae commented 3 months ago

nnx.jit's API is changing very soon (#3963), according to #4107. For now, you can use nnx.iter_graph to traverse the entire model and jax.device_put to shard the arrays inside its Variables.

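import jax
from flax import nnx

# nnx.iter_graph returns an iterator, so loop over it with `for` (not `with`).
for path, parent in nnx.iter_graph(model):
  if isinstance(parent, nnx.Module):
    for name, value in vars(parent).items():
      if isinstance(value, nnx.Variable):
        sharding = ...  # e.g. a jax.sharding.NamedSharding chosen for this Variable
        value.value = jax.device_put(value.value, sharding)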
mmorinag127 commented 1 month ago

Hello,

I have the following code for this purpose, but it raises an error like ValueError: Mismatch custom node data: NodeDef(...

Could you please help me?

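import optax
from flax import nnx

def _create_state():
    model = Model(...)  # the user's model (definition not shown)
    tx = optax.adam(...)
    wrt = nnx.All(nnx.Param, nnx.Everything())
    optimizer = nnx.Optimizer(model, tx, wrt=wrt)
    # Split model and optimizer together into a static graphdef and an array state.
    graphdef, state = nnx.split((model, optimizer))
    return graphdef, state

# `mesh` is a jax.sharding.Mesh defined elsewhere (not shown).
with mesh:
    # Trace abstractly (no real arrays allocated) to get the output structure.
    abst = nnx.eval_shape(_create_state)
    # Derive NamedShardings from the Variables' sharding metadata.
    shardings = nnx.get_named_sharding(abst, mesh)
    graphdef, state = nnx.jit(_create_state, out_shardings=shardings)()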