Qiskit-Extensions / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-extensions.github.io/qiskit-dynamics/
Apache License 2.0

Add grid_map utility for managing JAX parallelization/vectorization #355

Open DanPuzzuoli opened 3 months ago


Summary

grid_map is a utility for mapping a function over a "grid" of argument values. For example, for two arrays a and b,

grid_map(f, a, b)[i, j] = f(a[i], b[j])
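For the two-argument case, the semantics above can be reproduced with two nested jax.vmap calls: the inner map runs over b with a fixed, and the outer map runs over a. This is a minimal sketch for illustration only, not the PR's implementation (which uses xmap); grid_map_2d is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def grid_map_2d(f, a, b):
    """Hypothetical sketch: returns out with out[i, j] = f(a[i], b[j])."""
    # Inner map: vary the second argument, broadcast the first.
    inner = jax.vmap(f, in_axes=(None, 0))
    # Outer map: vary the first argument, broadcast the second.
    outer = jax.vmap(inner, in_axes=(0, None))
    return outer(a, b)


a = jnp.arange(3.0)
b = jnp.arange(4.0)
out = grid_map_2d(lambda x, y: x * 10 + y, a, b)
# out.shape == (3, 4) and out[i, j] == a[i] * 10 + b[j]
```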

This works for an arbitrary number of arguments and, more generally, for JAX PyTrees, using the standard conventions for mapping over PyTrees (explained in the function docstring). That is, for general PyTrees a and b with appropriate leaf shapes (all leaves within each argument must share the same leading dimension size for the mapping to make sense), the above expression holds as long as we interpret

v[idx] = tree_map(lambda x: x[idx], v)

where v is any of the indexed objects in the previous expression. That is, v[idx] is the PyTree obtained by indexing every leaf of v with idx (slightly abusing notation to allow idx to be a multi-index).
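The general n-argument, PyTree-aware version of these semantics can be sketched by stacking one jax.vmap per argument, each mapping only its own argument (in_axes=0, which applies to every leaf of that PyTree) and broadcasting the rest (in_axes=None). grid_map_sketch and index_tree are hypothetical helper names, and this nested-vmap approach is an assumed stand-in for illustration, not the xmap-based implementation in the PR.

```python
import jax
import jax.numpy as jnp


def index_tree(v, idx):
    """v[idx] in the notation above: index every leaf of the PyTree v with idx."""
    return jax.tree_util.tree_map(lambda x: x[idx], v)


def grid_map_sketch(f, *args):
    """Hypothetical sketch: out[i_0, ..., i_{n-1}] = f(args[0][i_0], ..., args[n-1][i_{n-1}])."""
    n = len(args)
    mapped = f
    # Wrap from the last argument outward, so the output axes follow argument order.
    for k in reversed(range(n)):
        # Map over argument k only; all other arguments are broadcast unchanged.
        in_axes = tuple(0 if m == k else None for m in range(n))
        mapped = jax.vmap(mapped, in_axes=in_axes)
    return mapped(*args)


# A PyTree argument: both leaves share the leading dimension 3.
a = {"x": jnp.arange(3.0), "y": jnp.ones((3, 2))}
b = jnp.arange(4.0)


def f(p, s):
    return p["x"] * s + jnp.sum(p["y"])


out = grid_map_sketch(f, a, b)
# out.shape == (3, 4); out[1, 2] == f(index_tree(a, 1), b[2]) == 4.0
```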

Beyond the usefulness of this form of mapping itself, the main point of grid_map is to give control over how the evaluations of f are parallelized. Under the hood, it uses JAX's xmap to execute the mapping through a combination of device parallelization and vectorization (which the JAX documentation describes as an "interpolation" between pmap and vmap). It makes natural default choices based on the available device types, and users can control these choices directly through optional arguments.
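The idea of splitting work between device parallelization and vectorization can be illustrated without xmap. The sketch below reshapes the leading axis into (n_devices, per_device), maps the first axis across devices with jax.pmap and vectorizes the second with jax.vmap. It is an assumed stand-in for illustration, not the PR's code: device_vmap is a hypothetical name, it handles a single array argument only (no PyTrees), and padding by repeating the last element is an assumption made so that the batch divides evenly across devices.

```python
import jax
import jax.numpy as jnp


def device_vmap(f, x, devices=None):
    """Hypothetical sketch: map f over x's leading axis, pmap across devices, vmap within each."""
    devices = devices if devices is not None else jax.local_devices()
    n_dev = len(devices)
    n = x.shape[0]
    # Pad the batch (repeating the last element) so it divides evenly across devices.
    per_dev = -(-n // n_dev)  # ceil(n / n_dev)
    pad = per_dev * n_dev - n
    x_padded = jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0)
    # Leading axis -> devices (pmap); second axis -> vectorized (vmap).
    x_split = x_padded.reshape((n_dev, per_dev) + x.shape[1:])
    out = jax.pmap(jax.vmap(f), devices=devices)(x_split)
    # Undo the reshape and drop the padded entries.
    return out.reshape((n_dev * per_dev,) + out.shape[2:])[:n]


x = jnp.arange(10.0)
y = device_vmap(lambda v: v ** 2, x)
# y matches jax.vmap(lambda v: v ** 2)(x) regardless of the number of local devices
```

On a single CPU device this reduces to plain vmap; with several devices, per_dev shrinks and more of the work runs in parallel across devices, which is the trade-off grid_map exposes through its optional arguments.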

In Dynamics, this will be useful for controlling parallelization internally, whether running on CPU or GPU, and it will also be generally useful to give users direct access to it.

Details and comments

This is currently a work in progress. This helper function was written to control parallelization/mapping in some research projects, and its design/implementation will need to be revisited for integration into Dynamics.

General to do:

Design questions:

Update (23/4/24):