Open amifalk opened 1 month ago
Hm, interesting idea!
One question is what the API for this should be. Some ideas:

- `nmap`: something like `nmap(fn, batch_sizes={"foo": 2, "bar": 4})`. This would be easy to implement, but a disadvantage is that you'd need to plumb through the batch sizes for each axis into each computation. For instance, if you wanted to add a new batch axis but map over it, you'd need to modify all of the calls to `nmap` inside the function you are calling, which would be fairly annoying.
- Binding batch sizes to the axis names themselves, so that each `nmap` call could automatically pick up and use the batch size. I think this would be pretty complex to implement, though. It also doesn't compose well with `tag`/`untag`, although I guess the rule could just be that untagging an axis resets its batching size?
- A context manager: `nmap` calls inside the context manager would read from it. That would make it easier to specify batch sizes on a per-name level. However, it might not work well with JAX tracing, since tracers would have to be aware of the context manager somehow.

It occurs to me that adding batched mapping directly to `nmap` may not be the right thing to do here.
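To make the context-manager idea concrete, here's a toy pure-Python sketch of how such a batch-size registry could behave. `batch_nmap` and `current_batch_size` are made-up names, and this deliberately ignores the hard part, which is making JAX tracers aware of the registry:

```python
import contextlib
import contextvars

# Hypothetical registry mapping axis names to batch sizes (made-up).
_batch_sizes = contextvars.ContextVar("batch_sizes", default={})

@contextlib.contextmanager
def batch_nmap(batch_sizes):
    # Merge with any enclosing batch sizes so nested managers compose.
    token = _batch_sizes.set({**_batch_sizes.get(), **batch_sizes})
    try:
        yield
    finally:
        _batch_sizes.reset(token)

def current_batch_size(axis_name, default=None):
    # What a hypothetical batching-aware nmap would call internally.
    return _batch_sizes.get().get(axis_name, default)

with batch_nmap({"batch": 10}):
    inside = current_batch_size("batch")   # 10 inside the manager
outside = current_batch_size("batch")      # None outside
```

A real implementation would additionally need each `nmap` trace to consult this registry, which is exactly where the interaction with JAX tracing gets tricky.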
I'm thinking about this feature in the context of batching model evaluations on a grid of data inputs. This is important when serving models with many requests coming in or performing a grid search over hyperparameters, just to give two examples. In this context, I only ever want to batch at the outermost level of the code.
You might imagine that an arbitrary neural network in Penzai looks something like this under the hood:

```python
def eval_nn(x):
  x = nmap(custom_layer_1)(x)
  x = nmap(custom_layer_2)(x)
  x = nmap(custom_layer_3)(x)
  return x
```
When we batch-nmap this over a grid, the batching will be pushed inside each `nmap` if we use any global configuration applied to the `nmap` operator (as in the binding or context manager proposal), e.g.:

```python
def eval_nn(x):
  x = nmap(custom_layer_1)(x)
  x = nmap(custom_layer_2)(x)
  x = nmap(custom_layer_3)(x)
  return x

inputs = pz.nx.ones({"batch": 1_000, "x": 50, "y": 50})

with pz.nx.batch_nmap(batch_sizes={"batch": 10}):
  eval_nn(inputs)
```
will evaluate to:

```python
def eval_nn(x):
  x = nmap(custom_layer_1, batch=10)(x)
  x = nmap(custom_layer_2, batch=10)(x)
  x = nmap(custom_layer_3, batch=10)(x)
  return x
```
I made a little micro-benchmark, and piping the map inward is ~40% slower than batch-mapping on the outside on JAX 0.4.31 with an NVIDIA GeForce RTX 4090, likely due to the transfer overheads between the host and device after each scan.
```
143 ms ± 72.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)  # internal piping of scan
105 ms ± 68.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)  # one map on the outside
```
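For intuition about what's being compared, here is a minimal JAX sketch of the two strategies (not the author's actual benchmark; the two layers are invented): sequential `lax.map` applied inside each layer versus a single `lax.map` on the outside. The outputs agree; only the scheduling differs.

```python
import jax
import jax.numpy as jnp

def layer_1(x):
    return jnp.tanh(2.0 * x)

def layer_2(x):
    return x + jnp.sin(x)

def eval_nn_piped(xs):
    # Mapping pushed inside: one sequential scan per layer.
    xs = jax.lax.map(layer_1, xs)
    xs = jax.lax.map(layer_2, xs)
    return xs

def eval_nn_outer(xs):
    # One sequential scan over the whole network.
    return jax.lax.map(lambda x: layer_2(layer_1(x)), xs)

xs = jnp.linspace(0.0, 1.0, 8).reshape(8, 1)
piped = eval_nn_piped(xs)
outer = eval_nn_outer(xs)
```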
Instead, maybe we should have some function `pz.nx.batch` (this could also be a decorator) that batch-nmaps only over the specified axes and does not pipe the mapping into the internal functions. Like `lax.map`, this would just be syntactic sugar over `lax.scan`. Here's some rough pseudo-code:
```python
def batch(fn, xs, batch_sizes=None):
  if batch_sizes is None:
    batch_sizes = {}
  batch_axes = tuple(batch_sizes.keys())
  axis_sizes = tuple(batch_sizes.values())
  # Split each batch axis into (num_chunks, chunk_size), scan over the
  # chunks, then stack the results back together along the named axes.
  xs = xs.untag(*batch_axes).reshape((-1,) + axis_sizes)
  return pz.nx.stack(pz.nx.scan(fn, axes=batch_axes), axes=batch_axes)

# usage:
inputs = pz.nx.ones({"batch": 1_000, "x": 50, "y": 50})
pz.nx.batch(eval_nn, inputs, batch_sizes={"batch": 10})
```
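As a plain-NumPy sanity check on the reshape-then-scan idea (with a Python loop standing in for `lax.scan`, and `batched_apply` a made-up name), the chunked evaluation matches applying the function to the whole batch at once; the remainder case is simply asserted away here:

```python
import numpy as np

def batched_apply(fn, xs, batch_size):
    # Split the leading batch axis into (num_chunks, batch_size) chunks,
    # apply fn sequentially to each chunk, and reassemble the results.
    n = xs.shape[0]
    assert n % batch_size == 0, "remainder handling omitted in this sketch"
    chunks = xs.reshape((n // batch_size, batch_size) + xs.shape[1:])
    out_chunks = [fn(chunk) for chunk in chunks]  # sequential, like a scan
    return np.concatenate(out_chunks, axis=0)

xs = np.arange(12.0).reshape(12, 1)
chunked = batched_apply(lambda c: c * 2.0 + 1.0, xs, batch_size=4)
direct = xs * 2.0 + 1.0
```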
What do you think?
I wrote up a prototype of this functionality described above ^ here: https://gist.github.com/amifalk/e21059da7f0c0ecb3db8240604413998
I realized there's no benefit to allowing different batch sizes for different axes given that they're all evaluated independently. In the worst case, the remainder named array won't have the same shape as the batch array, so it will not be possible to broadcast them back together.
I agree that having a single `scan` on the outside seems better than having a number of smaller scans on the inside for this use case. Thanks for running the benchmark, and for taking a stab at the implementation!
I wonder if the best solution here would be to aim to match the semantics of `jax.lax.map` with a named-axis version, similar to the relationship between `jax.lax.scan` and `pz.nx.scan`. If so, this would suggest that:

- it should be syntactic sugar over a scan (like both `jax.lax.map` and `pz.nx.scan`),
- unlike `jax.lax.map`, the mapped-over axis should be a named axis (like `pz.nx.scan`),
- the `batch_size` option should be optional (similar to `jax.lax.map`), and default to 1,
- the batching should stay outside of `f`, to keep it as-is.

One question is what the name of this should be:

- `pz.nx.map` would fit with `pz.nx.scan`, but might get confused with `pz.nx.nmap`?
- `pz.nx.serial_map` is more explicit but more verbose.

What do you think?
(This makes me think of a related question: would it be useful to provide `pz.nx.vmap`, which is like `jax.vmap` but maps over a single named axis in parallel, keeping all the other named axes? This would pull the batching out instead of pushing it inward to the inner functions. I'm not sure if there's much of a reason to have this, though. Perhaps avoiding name conflicts?)
Agreed w.r.t. all points emulating `jax.lax.map`. My gut says the name `pz.nx.map` is fine given that it will have a different function signature than `nmap`, but I don't have a strong feeling between that and `serial_map`.

On `pz.nx.vmap`, I would guess the XLA wouldn't change, and it strikes me as a bit of an anti-pattern that would make functions more brittle and might confuse new users (once you nmap, you never go back!).
Sometimes `nmap`'ed computations don't all fit in memory at once, and there are not enough devices to shard the computation over (this limitation is particularly salient when using Penzai because adding an arbitrary number of named axes is so darn convenient :) ). JAX now supports semantics for batched vmapping with `jax.lax.map`. This would be awesome to add to Penzai!