jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Assertion error for pmap out_axes with nested None value #14296

Closed: simon-bachhuber closed this issue 3 months ago

simon-bachhuber commented 1 year ago

Dear JAX team,

Is the following no longer considered a reasonable pattern? I could swear this used to work, and I found it quite concise. It would be nice to have it back.

Is it more or less equivalent to do a dimension reduction outside of the pmap?

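```python
import functools as ft

import jax
import jax.numpy as jnp

# Reproducer: jax.grad over a pmap whose out_axes contains None. Calling
# loss_fn(params, X, y) hits the AssertionError reported in this issue
# (on the JAX versions discussed here).
@ft.partial(jax.grad, has_aux=True)
@ft.partial(jax.pmap, in_axes=(None, 0, 0), out_axes=(None, 0), axis_name="pmap")
def loss_fn(params, X, y):
    yhat = params @ X
    mse = jnp.mean((y - yhat)**2)
    return jax.lax.pmean(mse, "pmap"), yhat
```
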
mattjj commented 1 year ago

Thanks for raising this.

pmap's out_axes=None was only ever partially implemented, and I don't think reverse-mode autodiff over it was ever implemented. It could be implemented, though, in principle! (The analogous thing works with vmap, and in pmap with things other than reverse-mode, so maybe you were remembering one of those?)

> Is it more or less equivalent to do a dimension reduction outside of the pmap?

When the output of a pmap function is equal across all mapped instances of the function, then out_axes=None means "just choose one of these (equal) values". That is, when all values are equal, pmap(f, ..., out_axes=None)(xs) == pmap(f, ..., out_axes=0)(xs)[0]. I don't expect using out_axes=None to be any more efficient either, though I could be wrong.

So you might consider just doing the slicing yourself, i.e. replace any None entries of out_axes with 0, and then pick any index along those axes of the output.
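
A minimal sketch of that workaround applied to the loss from the original report. It assumes a per-device batch X of shape (batch, d), a parameter vector of shape (d,), and uses `X @ params` in place of `params @ X`; the helper `none_to_zero` is hypothetical, not a JAX API. The None entries of out_axes become 0, and the replicated loss is sliced at index 0 inside the function that jax.grad differentiates:

```python
import functools as ft

import jax
import jax.numpy as jnp
from jax import tree_util


def none_to_zero(out_axes):
    """Hypothetical helper: replace every None in a (nested) out_axes with 0."""
    # None is an empty pytree by default, so mark it as a leaf to visit it.
    return tree_util.tree_map(lambda a: 0 if a is None else a, out_axes,
                              is_leaf=lambda a: a is None)


@ft.partial(jax.pmap, in_axes=(None, 0, 0),
            out_axes=none_to_zero((None, 0)), axis_name="pmap")
def per_device_loss(params, X, y):
    yhat = X @ params  # assumed shapes: X (batch, d), params (d,)
    mse = jnp.mean((y - yhat) ** 2)
    return jax.lax.pmean(mse, "pmap"), yhat


def loss_fn(params, X, y):
    mse, yhat = per_device_loss(params, X, y)
    # Every entry of mse holds the same pmean'd value, so index 0 stands in
    # for what out_axes=None would have returned.
    return mse[0], yhat


grad_fn = jax.grad(loss_fn, has_aux=True)

# Usage: one slice of X and y per local device.
n = jax.local_device_count()
params = jnp.array([1.0, -2.0, 0.5])
X = jnp.arange(n * 4 * 3, dtype=jnp.float32).reshape(n, 4, 3) / 10
y = jnp.ones((n, 4))
grads, yhat = grad_fn(params, X, y)  # grads: (3,), yhat: (n, 4)
```

The slice happens inside loss_fn rather than after the call because jax.grad needs a scalar output to differentiate.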

Just a few days ago we started landing an upgrade to pmap called shmap (or shard_map); you can read the design doc here. It has an analogue of out_axes=None: an out_specs entry that doesn't mention one of the mesh axis names. Not only does that work (with reverse-mode AD), it's also typically more efficient than pmap. In particular, with shmap, leaving a mesh axis name out of an out_specs entry switches the output to a replicated layout along that mesh axis, without having to move or delete any buffers.
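
For comparison, a minimal sketch of the same loss written with shard_map. It assumes a JAX release where either `jax.shard_map` or `jax.experimental.shard_map.shard_map` is available (the import location has moved between releases), the same hypothetical shapes as above, and a mesh axis named "data", which is our own choice. The loss's out_specs entry is `P()`, which doesn't mention "data", so it plays the role of out_axes=None and can be differentiated directly:

```python
import functools as ft

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Newer JAX releases expose jax.shard_map; older ones only the experimental module.
shard_map = getattr(jax, "shard_map", None)
if not callable(shard_map):
    from jax.experimental.shard_map import shard_map

# 1-D mesh over all devices; "data" is an arbitrary axis name.
mesh = Mesh(np.array(jax.devices()), ("data",))


@ft.partial(
    shard_map,
    mesh=mesh,
    in_specs=(P(), P("data"), P("data")),  # params replicated; X, y split by rows
    out_specs=(P(), P("data")),            # loss replicated; yhat split by rows
)
def loss_fn(params, X, y):
    yhat = X @ params  # local shapes: X (rows, d), params (d,)
    mse = jnp.mean((y - yhat) ** 2)
    return jax.lax.pmean(mse, "data"), yhat


# Usage: the global row count must be divisible by the number of devices.
n = len(jax.devices())
params = jnp.array([1.0, -2.0, 0.5])
X = jnp.arange(n * 4 * 3, dtype=jnp.float32).reshape(n * 4, 3) / 10
y = jnp.ones((n * 4,))
grads, yhat = jax.grad(loss_fn, has_aux=True)(params, X, y)  # grads (3,), yhat (4n,)
```

Because `P()` declares the loss replicated, no explicit slicing step is needed, which is the main difference from the pmap workaround.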

I expect we'll ultimately implement pmap in terms of shard_map, and at that time out_axes=None will probably start working. But until then I don't think we plan to work more on pmap.

simon-bachhuber commented 1 year ago

Ha, you are probably right. I just tried it in an old conda env with jax=0.2.26, and I get the same assertion there.

Just as a side note, the assertion is a little unhelpful; maybe it would be better to simply print that pmap currently does not support None in out_axes, but either way it's a minor point.

That's really interesting; thanks for the link. In particular, pjit sounds like a nice addition: I really don't have to write the parallel logic myself ;)

Anyways, thank you for working on this amazing piece of software, and feel free to close this any time.

mattjj commented 1 year ago

Thanks for the kind words :)

> Just as a side note, the assertion is a little unhelpful; maybe it would be better to simply print that pmap currently does not support None in out_axes, but either way it's a minor point.

You're absolutely right! Let's do that.
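
As a rough illustration of the clearer message suggested above, a user-side guard could turn the opaque assertion into a readable error before pmap is called. `check_no_none_out_axes` is a hypothetical helper, not part of JAX, and it is deliberately conservative: it rejects None everywhere, including the forward-only uses that are partially implemented:

```python
import jax


def check_no_none_out_axes(out_axes):
    """Hypothetical guard: reject None anywhere in a (possibly nested) out_axes."""
    # tree_leaves skips None by default, so treat None as a leaf to see it.
    leaves = jax.tree_util.tree_leaves(out_axes, is_leaf=lambda a: a is None)
    if any(a is None for a in leaves):
        raise ValueError(
            f"pmap does not support None in out_axes (got {out_axes!r}); "
            "use 0 and index the replicated output instead."
        )
    return out_axes


# Usage: pass the checked value straight into pmap, e.g.
# jax.pmap(f, out_axes=check_no_none_out_axes((0, 0)))
```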

mattjj commented 3 months ago

After more than a year of work on this, #22585 will finally fix it.