Teculos opened 8 hours ago
Just saw that I was mistaken in thinking that there wasn't an nnx.pmap... I was confused since it isn't included in the transforms documentation.
Regardless, I'd love to know if my approach roughly approximates the pmap/pmean strategy used in the mentioned repo.
I'm trying to roughly replicate the behaviour found in this repo, where they pmap a scan transform over the train step to combine multiple train steps into one function call (see run_lib.py line 124). Since this is a pre-flax.nnx implementation, they replicate the model and pmap over the model replicas and data (structured as [combined steps, jax.device_count(), batch_size // jax.device_count(), *data dims]).
Ergo, they pmap across the second dimension and scan across the first to distribute the forward pass across GPUs, jax.lax.pmean the gradients, update the model, and iterate to the next step in the scan.
Since pmap has no flax.nnx equivalent, my approach was to shard the data across the batch dimension (my data has shape [combined steps, batch_size, *data dims]) and replicate the model on each GPU to distribute the forward pass, although I'm not certain I'm going about it properly. See below for a minimal example with a simple model and random data/labels.
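In sketch form, the data-parallel part of what I mean looks roughly like this (toy model, shapes, and a manual SGD update just for illustration; the full example is further down):

```python
import jax
import jax.numpy as jnp
import optax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flax import nnx

mesh = Mesh(jax.devices(), axis_names=("batch",))
replicated = NamedSharding(mesh, P())          # no partitioning: full copy on every device
shard_batch = NamedSharding(mesh, P("batch"))  # split axis 0 (the batch axis) across devices

# Toy stand-in for my model and data.
model = nnx.Linear(32, 10, rngs=nnx.Rngs(0))
x = jax.device_put(jnp.ones((256, 32)), shard_batch)
y = jax.device_put(jnp.zeros((256,), dtype=jnp.int32), shard_batch)

# Replicate the model parameters on every device.
nnx.update(model, jax.device_put(nnx.state(model), replicated))

@nnx.jit
def train_step(model, x, y):
    def loss_fn(m):
        logits = m(x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    # Plain SGD, just to keep the sketch self-contained.
    params = nnx.state(model, nnx.Param)
    nnx.update(model, jax.tree.map(lambda p, g: p - 1e-3 * g, params, grads))
    return loss

loss = train_step(model, x, y)  # one "combined step"; the scan/loop over steps wraps this
```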
Specifically, I'd like to know why there's a jax.lax.pmean explicitly stated in the losses.py generated loss function (line 229), but everything appears to work without a jax.lax.pmean equivalent in my example; it seems like some flax magic is happening behind the scenes that I'm kinda confused about.