google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Clarification on sharding strategy to combine multiple training steps together via nnx.scan #4417

Open Teculos opened 8 hours ago

Teculos commented 8 hours ago

I'm trying to roughly replicate the behaviour found in this repo, where they pmap a scan transform over the train step to combine multiple train steps into one function call (see run_lib.py line 124). Since this is a pre-flax.nnx implementation, they replicate the model and pmap over the model replicas and the data (structured as [combined steps, jax.device_count(), batch_size // jax.device_count(), *data dims]).

Ergo, they pmap across the second (device) dimension and scan across the first to distribute the forward pass across GPUs, jax.lax.pmean the gradients, update the model, and move on to the next step in the scan.
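Roughly, that pattern looks like the sketch below. This is my own illustrative reconstruction, not the repo's actual code: the "model" is a dummy linear function with made-up shapes, params/opt_state are replicated with flax.jax_utils.replicate, and I've put the device axis first so pmap's default leading-axis mapping applies.

import functools

import jax
import jax.numpy as jnp
import optax
from flax import jax_utils

tx = optax.adam(1e-3)

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]      # stand-in for the real model apply
    return optax.l2_loss(pred, y).mean()

def train_step(carry, batch):
    params, opt_state = carry
    x, y = batch
    grads = jax.grad(loss_fn)(params, x, y)
    # average the per-device gradients so every replica applies the same update
    grads = jax.lax.pmean(grads, axis_name="batch")
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return (params, opt_state), None

# pmap over the device axis, lax.scan over the "combined steps" axis on each device
@functools.partial(jax.pmap, axis_name="batch")
def train_n_steps(params, opt_state, xs, ys):
    (params, opt_state), _ = jax.lax.scan(train_step, (params, opt_state), (xs, ys))
    return params, opt_state

n_devices = jax.local_device_count()
params = {"w": jnp.zeros((20, 1)), "b": jnp.zeros((1,))}
params, opt_state = jax_utils.replicate(params), jax_utils.replicate(tx.init(params))
# dummy data laid out as [n_devices, combined steps, per-device batch, data dim]
xs = jnp.ones((n_devices, 5, 100 // n_devices, 20))
ys = jnp.ones((n_devices, 5, 100 // n_devices, 1))
params, opt_state = train_n_steps(params, opt_state, xs, ys)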

Since pmap has no flax.nnx equivalent, my approach was to shard the data across the batch dimension (my data has shape [combined steps, batch_size, *data dims]) and replicate the model on each GPU to distribute the forward pass, although I'm not certain I'm going about it properly. See below for a minimal example with a simple model and random data/labels.

from flax import nnx
from jax.sharding import NamedSharding, PartitionSpec

import jax
import optax

#Data is of shape [steps, batch, data dim]
data = jax.random.normal(jax.random.PRNGKey(1), (5, 100, 20))
label = jax.numpy.ones((5, 100, 1))
model = nnx.Sequential(nnx.Linear(20, 30, rngs=nnx.Rngs(0)),
                       nnx.Linear(30, 1, rngs=nnx.Rngs(1)))

#Unsharded data/model
jax.debug.visualize_array_sharding(data[0])

[sharding visualization output]

jax.debug.visualize_array_sharding(model.layers[0].kernel.value)

[sharding visualization output]


#shard data across the batch dimension (axis 1); the leading axis stays the steps axis for the scan
mesh = jax.make_mesh((jax.device_count(),), ("batch",))
data_sharding = NamedSharding(mesh, PartitionSpec(None, "batch"))

sharded_data = jax.device_put(data, data_sharding)
sharded_label = jax.device_put(label, data_sharding)

#shard model (with no sharding annotations on the params, the constraint below replicates them)
def create_sharded_model(model):
    state = nnx.state(model)                    # the model's state, a pure pytree
    pspecs = nnx.get_partition_spec(state)      # PartitionSpecs derived from the state's annotations
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)            # write the constrained state back into the model
    return model

with mesh:
    sharded_model = create_sharded_model(model)

tx = optax.adam(1e-3)
state = nnx.Optimizer(model, tx)

#Sharded data/label: at this point the data should be sharded across GPUs
#with the model replicated (???)
jax.debug.visualize_array_sharding(sharded_data[0])

[sharding visualization output]

jax.debug.visualize_array_sharding(sharded_model.layers[0].kernel.value)

[sharding visualization output]
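# The same check programmatically (via the standard jax.Array .sharding attribute):
print(sharded_data.sharding)          # expect NamedSharding with PartitionSpec(None, 'batch')
print(sharded_model.layers[0].kernel.value.sharding.is_fully_replicated)  # expect True if replication worked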

loss_fn = lambda model, x, y: optax.l2_loss(model(x), y).mean()

def step_fn(batch_data, batch_label, state):
    grads = nnx.grad(loss_fn)(state.model, batch_data, batch_label)
    state.update(grads=grads)
    return state

#combine multiple train steps into a single scan carrying over the state
scanned_train = nnx.jit(nnx.scan(step_fn, in_axes=(0, 0, nnx.Carry), out_axes=nnx.Carry,
                                 transform_metadata={nnx.PARTITION_NAME: "batch"}))

#returns the state after 5 scanned+jitted train steps
new_state = scanned_train(sharded_data, sharded_label, state)

Specifically I'd like to know:

  1. Is my approach to sharding a model with no annotations the best practice for replicating a model across devices?
  2. Am I correct in thinking this sharding formulation will have the replicated models run the forward pass on the subset of batch observations located on their respective GPUs?
    • It's a bit hard to determine exactly what's happening under the hood here (one way to inspect this is sketched after this list).
  3. With data sharded across GPUs but a replicated model, how exactly are gradients calculated/combined?
    • The repo I'm trying to mirror calls jax.lax.pmean explicitly in the loss function generated in losses.py (line 229), but it seems like some flax magic is happening behind the scenes that I'm confused about, because everything appears to work without a jax.lax.pmean equivalent in my example.
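For points 2 and 3, one way I can think of to peek under the hood is to lower a jitted grad function with sharded inputs and search the compiled HLO for collectives like all-reduce, and to check the sharding of the returned gradients. Here is a standalone sketch with dummy shapes and a plain linear loss (not my actual nnx code above):

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

mesh = jax.make_mesh((jax.device_count(),), ("batch",))
x = jax.device_put(jnp.ones((100, 20)), NamedSharding(mesh, PartitionSpec("batch")))
y = jax.device_put(jnp.ones((100, 1)), NamedSharding(mesh, PartitionSpec("batch")))
w = jax.device_put(jnp.ones((20, 1)), NamedSharding(mesh, PartitionSpec()))  # replicated

def loss(w, x, y):
    return ((x @ w - y) ** 2).mean()

grad_fn = jax.jit(jax.grad(loss))
hlo = grad_fn.lower(w, x, y).compile().as_text()
print("all-reduce in compiled HLO:", "all-reduce" in hlo)  # compiler-inserted cross-device reduction, if any
print(grad_fn(w, x, y).sharding)                           # how the gradient itself is laid out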
Teculos commented 8 hours ago

Just saw that I was mistaken in thinking there wasn't an nnx.pmap... I was confused since it isn't included in the transforms documentation.

Regardless, I'd love to know if my approach roughly approximates the pmap/pmean strategy used in the mentioned repo.