google-deepmind / penzai

A JAX research toolkit for building, editing, and visualizing neural networks.
https://penzai.readthedocs.io/
Apache License 2.0

FR: batched mapping for `named_axes.nmap` #71

Open amifalk opened 1 month ago

amifalk commented 1 month ago

Sometimes nmap'ed computations don't all fit in memory at once, and there are not enough devices to shard the computation over (this limitation is particularly salient when using penzai because adding an arbitrary number of named axes is so darn convenient :) )

JAX now supports batched mapping via jax.lax.map's batch_size argument (a scan over chunks, vmapping within each chunk). This would be awesome to add to Penzai!
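
For reference, the batched behavior looks roughly like this (a minimal sketch using jax.lax.map's batch_size keyword):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

xs = jnp.ones((1_000, 50, 50))

# Scans over xs in chunks of 10 rows, vmapping f within each chunk, so only
# one chunk's worth of intermediates is live on the device at a time.
ys = jax.lax.map(f, xs, batch_size=10)  # ys.shape == (1_000,)
```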

danieldjohnson commented 1 month ago

Hm, interesting idea!

One question is what the API for this should be. Some ideas:

amifalk commented 1 month ago

It occurs to me that adding batched mapping directly to nmap may not be the right thing to do here.

I'm thinking about this feature in the context of batching model evaluations on a grid of data inputs. This is important when serving models with many requests coming in or performing a grid search over hyperparameters, just to give two examples. In this context, I only ever want to batch at the outermost level of the computation.

You might imagine that an arbitrary neural network in Penzai looks something like this under the hood:

def eval_nn(x): 
    x = nmap(custom_layer_1)(x)
    x = nmap(custom_layer_2)(x)
    x = nmap(custom_layer_3)(x)
    return x 

When we batch-nmap this over a grid, the batching will be pushed inside each nmap if we use any global configuration applied to the nmap operator (as in the binding or context manager proposal).

e.g:

def eval_nn(x): 
    x = nmap(custom_layer_1)(x)
    x = nmap(custom_layer_2)(x)
    x = nmap(custom_layer_3)(x)
    return x 

inputs = pz.nx.ones({"batch": 1_000, "x": 50, "y": 50})

with pz.nx.batch_nmap(batch_sizes={"batch": 10}):
    eval_nn(inputs)

will evaluate to:

def eval_nn(x): 
    x = nmap(custom_layer_1, batch=10)(x)
    x = nmap(custom_layer_2, batch=10)(x)
    x = nmap(custom_layer_3, batch=10)(x)
    return x 

I made a little micro-benchmark: piping the map inside each layer is ~40% slower than batch-mapping once on the outside on JAX 0.4.31 with an NVIDIA GeForce RTX 4090, likely due to the transfer overheads between the host and device after each scan.

```python
#%%
import jax
import jax.numpy as jnp
import jax.random as random

def layer(arr):
    return jnp.matmul(arr, arr - jnp.mean(arr))

BATCH_SIZE = 20

# Variant 1: a separate batched map (scan) piped inside each layer.
def map_each_layer(batch_of_arrs):
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    return batch_of_arrs

def all_layers(arr):
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)
    return arr

# Variant 2: a single batched map wrapped around all of the layers.
def map_all_layers(batch_of_arrs):
    return jax.lax.map(all_layers, batch_of_arrs, batch_size=BATCH_SIZE)

map_each_layer_jit = jax.jit(map_each_layer)
map_all_layers_jit = jax.jit(map_all_layers)

batch_of_arrs = random.normal(random.PRNGKey(0), (500, 100, 100))

#%%
%time map_each_layer_jit(batch_of_arrs).block_until_ready()
%time map_all_layers_jit(batch_of_arrs).block_until_ready()

#%%
%timeit map_each_layer_jit(batch_of_arrs).block_until_ready()
%timeit map_all_layers_jit(batch_of_arrs).block_until_ready()
```
143 ms ± 72.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)  # map_each_layer: a scan piped inside each layer
105 ms ± 68.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)  # map_all_layers: one map on the outside

Instead, maybe we should have some function pz.nx.batch (it could also be a decorator) that batch-nmaps only over the specified axes and does not pipe the mapping into the internal functions.

Like lax.map, this would just be syntactic sugar over lax.scan. Here's some rough pseudo-code:

def batch(fn, batch_sizes=None):
    if batch_sizes is None:
        batch_sizes = {}
    batch_axes = tuple(batch_sizes.keys())
    axis_sizes = tuple(batch_sizes.values())

    def batched_fn(xs):
        # Fold each batched axis into (num_batches, batch_size) chunks so a
        # single outer scan can step over one chunk at a time.
        xs = xs.untag(*batch_axes).reshape((-1,) + axis_sizes)
        return pz.nx.stack(pz.nx.scan(fn, xs, axes=batch_axes), axes=batch_axes)

    return batched_fn

# usage: 
inputs = pz.nx.ones({"batch": 1_000, "x": 50, "y": 50})
pz.nx.batch(eval_nn, batch_sizes={"batch": 10})(inputs)
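
In plain JAX terms, the sugar would expand to something like this single outer scan (a rough sketch for one batched axis, not the actual implementation):

```python
import jax
import jax.numpy as jnp

def eval_nn(x):              # stand-in for the named-axis network above
    return x * 2.0

inputs = jnp.ones((1_000, 50, 50))
num_chunks, batch_size = 100, 10

# Reshape into chunks, scan once over the chunk axis while vmapping eval_nn
# within each chunk, then flatten back to the original batch axis.
chunked = inputs.reshape(num_chunks, batch_size, 50, 50)
_, out = jax.lax.scan(lambda carry, chunk: (carry, jax.vmap(eval_nn)(chunk)), None, chunked)
outputs = out.reshape(1_000, 50, 50)
```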

What do you think?

amifalk commented 1 month ago

I wrote up a prototype of the functionality described above here: https://gist.github.com/amifalk/e21059da7f0c0ecb3db8240604413998

I realized there's no benefit to allowing different batch sizes for different axes given that they're all evaluated independently. In the worst case, the remainder named array won't have the same shape as the batch array, so it will not be possible to broadcast them back together.
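
To make the remainder problem concrete, here is a toy shape calculation (hypothetical sizes) showing the mismatch between the full-size chunks and the leftover chunk:

```python
import jax.numpy as jnp

n, batch_size = 1_005, 10
xs = jnp.ones((n, 50, 50))

# The first 1_000 rows reshape cleanly into 100 chunks of 10 ...
main, remainder = jnp.split(xs, [n - n % batch_size])
chunks = main.reshape(-1, batch_size, 50, 50)    # (100, 10, 50, 50)

# ... but the leftover chunk has a different leading size, so it cannot be
# stacked or broadcast back together with the full chunks without padding.
print(remainder.shape)                           # (5, 50, 50)
```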

danieldjohnson commented 1 month ago

I agree that having a single scan on the outside seems better than having a number of smaller scans on the inside for this use case. Thanks for running the benchmark, and for taking a stab at the implementation!

I wonder if the best solution here would be to aim to match the semantics of jax.lax.map with a named-axis version, similar to the relationship between jax.lax.scan and pz.nx.scan. If so, this would suggest that:

One question is what the name of this should be:

What do you think?

(This makes me think of a related question: would it be useful to provide pz.nx.vmap which is like jax.vmap but maps over a single named axis in parallel, keeping all the other named axes? This would pull the batching out instead of pushing it inward to the inner functions. I'm not sure if there's much of a reason to have this, though. Perhaps avoiding name conflicts?)
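
For intuition, the distinction in plain JAX would look roughly like this (a hypothetical sketch, not penzai's implementation; layer and net are made-up stand-ins):

```python
import jax
import jax.numpy as jnp

def layer(x):
    return x @ (x - jnp.mean(x))

def net(x):
    return layer(layer(x))

xs = jnp.ones((8, 50, 50))

# nmap-style: the batching is pushed inward, each call gets its own vmap.
pushed_inward = jax.vmap(layer)(jax.vmap(layer)(xs))

# pz.nx.vmap-style: one vmap is pulled outward around the composed function.
pulled_outward = jax.vmap(net)(xs)
```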

amifalk commented 1 month ago

Agreed on all points w.r.t. emulating jax.lax.map. My gut says the name pz.nx.map is fine given that it will have a different function signature than nmap, but I don't have a strong feeling between that and serial_map.

On pz.nx.vmap, I would guess the emitted XLA wouldn't change, and it strikes me as a bit of an anti-pattern that would make functions more brittle and might confuse new users (once you nmap, you never go back!).