Open joeryjoery opened 2 years ago
Take a look at `jax.experimental.maps.xmap`, which is designed for exactly this sort of thing:
https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html
https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.xmap.html
You can chunk vmap by making use of `SerialLoop`.
This here might be something you could use: https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vmap_chunked.html#netket.jax.vmap_chunked
Thanks, that is an excellent reference. I am of the opinion, though, that JAX could implement this by default, given how generally useful it is.
Hi, more than a year later ;P, are there any plans on implementing this by default in JAX? Or would the authors be open to a PR?
@mattjj we were just talking about this :)
If you do get to it, I would love it if you also added support for `vjp` over 'chunked' axes, like we did in https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vjp_chunked.html
+1 for adding this feature, it would be super useful in several projects for me.
Discussion https://github.com/google/jax/discussions/18398 asks for this as well.
+1 upvote for implementing this
It would be nice if JAX's compiler could automatically convert computations from parallel to sequential when necessary, given a known memory constraint. Has this possibility been discussed anywhere?
`jax.lax.map` now has a `batch_size` argument that will chunk the computation and internally utilize `vmap` to operate over each batch in parallel. See #19614.
I tried using JAX as an alternative to numba, since it is much easier to work in, and I can switch between CPU and GPU using a flag. The biggest problem I am facing now is that `vmap` tries to put all of the data onto the GPU at the same time. I am already using vmap to apply the scalar function onto a 2D array, so I was wondering if and how `map` could help me batch the data, so that I am not getting out-of-memory errors!
Try the `batch_size` argument to `jax.lax.map`? (Added in #19614.)
Think of it as a sequential loop over batches, where each batch is vmapped.
Occasionally I run into the problem that batch sizes are too large for GPU memory, and using the current public API of JAX I have to commit to one of two extremes: `vmap` or the much more limited `scan`. (I'm still relatively new to JAX, so please correct me if I'm wrong.) Could there perhaps be a `chunk` argument to the `vmap` function that limits how much is passed to the `vmap`ped function at any time?

I tried implementing this, and have the following mock-up. At the moment this is based on a `num_chunks` parameter, but this could be made more flexible, for example in terms of `chunk_size`. It splits or repeats the inputs of the `vmap`ped function based on the given `in_axes`, pads all input arrays uniformly with zeroes along the batch dimension, and finally for-loops the chunks through a conventional `vmap` function. The results are then concatenated based on their canonical slices.

I don't have all the functionalities of the normal `vmap` implemented, but it works in a slightly more narrow scope. Could something like this be supported by the public JAX API?