DanPuzzuoli opened this pull request 7 months ago.
Summary
`grid_map` is a utility for mapping a function over a "grid" of argument values. E.g. for two arrays `a` and `b`,

`grid_map(f, a, b)[i, j] == f(a[i], b[j])`.

This works for an arbitrary number of arguments, and more generally works on JAX PyTrees using standard conventions for mapping over PyTrees (explained in the function docstring). I.e. for `a` and `b` general PyTrees with appropriate leaf shapes (all leaves of a given argument must have the same leading dimension size for the mapping to make sense), the above expression holds so long as we interpret `v[idx]`, for `v` any of the indexed objects in the expression, as the PyTree resulting from indexing every leaf of `v` with `idx` (here we are slightly abusing notation to allow `idx` to be a multi-index).

Aside from the usefulness of this form of mapping, the main point of `grid_map` is to offer control over how the evaluations of `f` get parallelized. Under the hood, it uses JAX's `xmap` to execute the mapping with a combination of device parallelization and vectorization (which the JAX documentation describes as an "interpolation" between `pmap` and `vmap`). It makes natural default choices based on the available device types, and a user can control these directly via optional arguments.

For Dynamics, this will be useful for internally controlling parallelization in the package, whether on CPU or GPU, and it will also be useful for users to have direct access to.
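To pin down the indexing semantics described above, independent of how the work is parallelized, here is a minimal sketch that reproduces them with nested `jax.vmap`. The name `grid_map_sketch` is hypothetical; this is not the implementation in this PR, and it offers none of the device/vectorization control that is the main point of `grid_map`.

```python
import jax
import jax.numpy as jnp


def grid_map_sketch(f, *args):
    """Hypothetical sketch of grid_map's semantics using nested jax.vmap.

    Each positional argument is a PyTree whose leaves share a leading axis.
    The result satisfies out[i_1, ..., i_n] == f(args[0][i_1], ..., args[n-1][i_n]),
    where indexing a PyTree means indexing each of its leaves.
    """
    g = f
    # Wrap from the last argument outward, so the first argument's grid axis
    # ends up as the outermost (leading) axis of every output leaf.
    for k in reversed(range(len(args))):
        # Map over axis 0 of argument k only; hold all other arguments fixed.
        in_axes = tuple(0 if i == k else None for i in range(len(args)))
        g = jax.vmap(g, in_axes=in_axes)
    return g(*args)


# Two plain arrays: out[i, j] == 10 * a[i] + b[j].
a = jnp.arange(3.0)
b = jnp.arange(4.0)
out = grid_map_sketch(lambda x, y: 10 * x + y, a, b)
# out.shape == (3, 4); out[1, 2] == 12.0

# A PyTree argument: both leaves of `params` share the leading dimension 3.
params = {"scale": jnp.arange(3.0), "offset": jnp.ones((3, 2))}
t = jnp.arange(2.0)
out_tree = grid_map_sketch(lambda p, s: p["scale"] * s + p["offset"].sum(), params, t)
# out_tree.shape == (3, 2)
```

Nested `vmap` fully vectorizes the whole grid, which is exactly the memory/parallelism trade-off that `grid_map`'s optional arguments are meant to let users tune.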
Details and comments
This is currently a work in progress. This helper function was written to control parallelization/mapping in some research projects, and its design/implementation will need to be revisited for integration into Dynamics.
General to do:
Design questions:
`max_vmap_size` and `devices`)? The other option is to allow the user to specify these optional arguments anywhere "internal" parallelization is an option (e.g. JAX-based pulse simulation).
`non_jax_argnums` argument being kind of clunky; can revisit whether we want to include this.

Update (23/4/24): `xmap` has been deprecated, so the inner workings of `grid_map` will need to be rethought. I think `shard_map` may be the natural replacement.
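One way a `shard_map`-based replacement could recover the device-parallel/vectorized interpolation is sketched here, under loud assumptions: `grid_map_sharded` and its `max_vmap_size` handling are hypothetical, it only handles two array (not PyTree) arguments with array outputs, and it materializes the flattened grid of arguments. The grid is flattened into argument pairs, padded so it splits evenly across devices and into chunks of at most `max_vmap_size`, sharded over a one-dimensional device mesh with `shard_map`, and evaluated on each device chunk by chunk with `jax.lax.map` over `jax.vmap`. The import fallback is there because `shard_map` has lived in `jax.experimental.shard_map` in older JAX releases.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases


def _ceil_div(x, y):
    # Integer ceiling division.
    return -(-x // y)


def grid_map_sharded(f, a, b, max_vmap_size=None):
    """Hypothetical sketch: returns out with out[i, j] == f(a[i], b[j]).

    a and b are arrays (not general PyTrees), and f must return an array.
    """
    n_a, n_b = a.shape[0], b.shape[0]
    # Flatten the grid: flat entry k pairs a[k // n_b] with b[k % n_b].
    a_flat = jnp.repeat(a, n_b, axis=0)
    b_flat = jnp.tile(b, (n_a,) + (1,) * (b.ndim - 1))
    total = n_a * n_b

    devices = np.array(jax.devices())
    n_dev = devices.size
    per_dev = _ceil_div(total, n_dev)
    # max_vmap_size caps how many evaluations are vectorized together.
    chunk = per_dev if max_vmap_size is None else min(max_vmap_size, per_dev)
    per_dev = _ceil_div(per_dev, chunk) * chunk
    pad = per_dev * n_dev - total

    def pad_rows(x):
        # Pad with copies of the first row; the padded results are discarded below.
        return jnp.concatenate([x, jnp.repeat(x[:1], pad, axis=0)], axis=0)

    a_flat, b_flat = pad_rows(a_flat), pad_rows(b_flat)

    def per_shard(a_loc, b_loc):
        # Each device sees per_dev rows; evaluate them in chunks of size `chunk`.
        a_chunks = a_loc.reshape((-1, chunk) + a_loc.shape[1:])
        b_chunks = b_loc.reshape((-1, chunk) + b_loc.shape[1:])
        out = jax.lax.map(lambda ab: jax.vmap(f)(*ab), (a_chunks, b_chunks))
        return out.reshape((-1,) + out.shape[2:])

    mesh = Mesh(devices, axis_names=("d",))
    sharded = shard_map(per_shard, mesh=mesh, in_specs=(P("d"), P("d")), out_specs=P("d"))
    out = jax.jit(sharded)(a_flat, b_flat)
    # Drop padding and restore the (n_a, n_b) grid shape.
    return out[:total].reshape((n_a, n_b) + out.shape[1:])


a = jnp.arange(3.0)
b = jnp.arange(5.0)
out = grid_map_sharded(lambda x, y: 10 * x + y, a, b, max_vmap_size=2)
# out.shape == (3, 5)
```

Setting `max_vmap_size=None` vectorizes each device's whole shard at once (the `vmap` end of the interpolation), while small values trade speed for lower peak memory; with several devices the leading axis is split across them (the `pmap` end).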