Closed simon-bachhuber closed 3 months ago
Thanks for raising this.
pmap's out_axes=None was only ever partially implemented, and I don't think reverse-mode autodiff over it was ever implemented. It could be implemented, though, in principle! (The analogous thing works with vmap, and in pmap with things other than reverse-mode, so maybe you were remembering one of those?)
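For reference, here is a minimal sketch of the vmap analogue mentioned above. The function f and the inputs are made up for illustration; the point is only that vmap accepts out_axes=None when the output does not depend on the mapped axis, and that reverse-mode autodiff goes through it.

```python
import jax
import jax.numpy as jnp

def f(x, w):
    # Hypothetical example function: the output ignores the mapped argument x,
    # so it is identical for every mapped instance and can be returned unbatched.
    del x
    return 2.0 * w

xs = jnp.arange(4.0)   # mapped along axis 0
w = jnp.array(3.0)     # not mapped (in_axes entry is None)

def loss(w):
    # out_axes=None asks vmap to return the (unbatched) output as-is.
    return jax.vmap(f, in_axes=(0, None), out_axes=None)(xs, w)

value = loss(w)             # scalar, no leading mapped axis
grad_w = jax.grad(loss)(w)  # reverse-mode autodiff through the vmap
```

Note that vmap raises an error if an output marked with out_axes=None actually depends on a mapped input.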
Is it more or less equivalent to have a dimension reduction outside of the pmap?
When the output of a pmap function is equal across all mapped instances of the function, then out_axes=None means "just choose one of these (equal) values". That is, when all values are equal, pmap(f, ..., out_axes=None)(xs) == pmap(f, ..., out_axes=0)(xs)[0]. I don't expect using out_axes=None to be any more efficient either, though I could be wrong.
So you might consider just doing the slicing yourself, i.e. replace any None entries of out_axes with 0, and then pick any index along those axes of the output.
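A minimal sketch of that workaround, assuming a made-up function f whose output is made equal across devices by a psum. The names f, loss, and the axis name "i" are illustrative and not taken from the original report.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # pmap maps over at most this many devices

def f(x):
    # psum over the mapped axis makes the result identical on every device.
    return jax.lax.psum(x ** 2, axis_name="i")

def loss(xs):
    # Use out_axes=0 instead of out_axes=None ...
    per_device = jax.pmap(f, axis_name="i", out_axes=0)(xs)  # shape (n,)
    # ... and pick any index along the mapped axis, since all entries are equal.
    return per_device[0]

xs = jnp.arange(1.0, n + 1.0)
value = loss(xs)
grads = jax.grad(loss)(xs)  # reverse-mode autodiff through the pmap
```

Because the slice happens outside the pmap, the differentiated function never sees out_axes=None.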
Just a few days ago we started landing an upgrade to pmap called shmap (or shard_map); you can read the design doc here. It has an analogue of out_axes=None, basically having an out_spec entry which doesn't mention one of the mesh axis names. Not only does that work (with reverse-mode AD), it's also typically more efficient than pmap. In particular, with shmap, not mentioning a mesh axis name in an out_spec basically switches the output to have a replicated layout along that mesh axis, without having to move or delete any buffers.
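A rough sketch of the shard_map version, under a few assumptions: a one-dimensional mesh over all available devices with a made-up axis name "i", and a made-up function f. The import location of shard_map has changed between JAX releases, so the sketch tries both.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

devices = np.array(jax.devices())
mesh = Mesh(devices, ("i",))  # 1-D mesh; the axis name "i" is arbitrary
n = devices.size

def f(x_block):
    # Each device sees a block of shape (1,); psum makes the result
    # identical (replicated) along the mesh axis "i".
    return jax.lax.psum(jnp.sum(x_block ** 2), axis_name="i")

def loss(xs):
    # out_specs=P() does not mention "i", which plays the role of
    # out_axes=None: the output is treated as replicated along that axis.
    return shard_map(f, mesh=mesh, in_specs=P("i"), out_specs=P())(xs)

xs = jnp.arange(1.0, n + 1.0)
value = jax.jit(loss)(xs)
grads = jax.jit(jax.grad(loss))(xs)  # reverse-mode AD through shard_map
```

Here no slicing is needed after the call, since the replicated output is returned directly as a scalar.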
I expect we'll ultimately implement pmap in terms of shard_map, and at that time out_axes=None will probably start working. But until then I don't think we plan to work more on pmap.
Ha, you are probably right. I just tried it in an old conda env with jax=0.2.26. The assertion is the same there.
Just as a side-note, the assertion is a little unhelpful; maybe it would be better to simply print out that currently pmap does not support None in out_axes, but either way it's a minor point.
That's really interesting, thanks for the link. Especially the pjit sounds like a nice addition -- I really don't have to write the parallel logic myself ;)
Anyways, thank you for working on this amazing piece of software, and feel free to close this any time.
Thanks for the kind words :)
Just as a side-note, the assertion is a little unhelpful; maybe it would be better to simply print out that currently pmap does not support None in out_axes, but either way it's a minor point.
You're absolutely right! Let's do that.
After working on this for more than a year, #22585 will finally fix it.
Dear JAX team,
is the following no longer considered a reasonable pattern? I could swear this used to work and found it quite concise. Would be nice to have it back.
Is it more or less equivalent to have a dimension reduction outside of the pmap?