jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unnecessary data movement when slicing ShardedDeviceArray #11920

Open danijar opened 2 years ago

danijar commented 2 years ago

Let's say we get a ShardedDeviceArray sde back from pmap with axis=0 being the device axis.

Currently, indexing (sde[3]) is fast but slicing (sde[:4]) triggers unnecessary data movement.

This can be really slow, especially when slicing a parameter tree with many leaves, and should be easy to avoid.
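A minimal reproduction sketch of the indexing vs. slicing comparison, applied to a parameter tree with many leaves. It assumes a CPU-only setup in which 8 devices are simulated with the XLA flag `--xla_force_host_platform_device_count=8`, and a hypothetical tree of 100 small leaves. CPU timings will not match the TPU cost described here. Newer JAX versions also return `jax.Array` from pmap instead of `ShardedDeviceArray`, so the gap may differ or may be gone, depending on the installed version.

```python
import os
# Assumption: simulate 8 CPU devices; the flag must be set before JAX is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import time
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

n = jax.local_device_count()
k = max(1, n // 2)  # slice off the first half of the devices

# Hypothetical parameter tree with many small leaves; pmap stacks each
# leaf along axis 0, which is the device axis.
leaves = {f"layer{j}": jnp.full((64, 64), float(j)) for j in range(100)}
params = jax.pmap(lambda x: tree_map(lambda p: p + x, leaves))(jnp.zeros(n))
jax.block_until_ready(params)

def bench(fn, reps=5):
    """Average wall-clock seconds per call, waiting for async dispatch."""
    out = fn()  # warm-up, so one-time compilation is not timed
    jax.block_until_ready(out)
    start = time.perf_counter()
    for _ in range(reps):
        out = fn()
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / reps

t_index = bench(lambda: tree_map(lambda x: x[0], params))  # sde[3]-style
t_slice = bench(lambda: tree_map(lambda x: x[:k], params))  # sde[:4]-style
print(f"index: {t_index * 1e3:.2f} ms/tree, slice: {t_slice * 1e3:.2f} ms/tree")
```

Because `tree_map` applies the operation once per leaf, any fixed per-slice overhead is multiplied by the number of leaves. This is why the slowdown is most visible on large parameter trees.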

mjsML commented 2 years ago

Can you share a bit more? What platform/accelerator are you using? Is there a code snippet you can share to reproduce the issue?

danijar commented 2 years ago

Hi @mjsML, this was on TPU (DF 2x2). This is the snippet I used; the leading axis of every leaf in the param tree is the pmapped device axis:

from jax.tree_util import tree_map
params = tree_map(lambda x: x[:4], params)  # keep the replicas of the first 4 devices

This was to get the 4 copies of the parameters that live on the first 4 of the 8 cores.
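A sketch of the snippet above on a hypothetical toy parameter tree, alongside a possible workaround that relies only on the observation in this thread that indexing the device axis has no overhead. Instead of one slice per leaf, it indexes each replica separately. This is an illustrative assumption, not an official fix. It again simulates 8 CPU devices via `XLA_FLAGS`, and the tree contents (`w`, `b`) are made up.

```python
import os
# Assumption: simulate 8 CPU devices; the flag must be set before JAX is imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

n = jax.local_device_count()
k = min(4, n)  # number of replicas to keep (4 of 8 in the issue)

# Hypothetical toy parameters, replicated by pmap so each leaf has a
# leading device axis of size n.
init = {"w": jnp.ones((16, 16)), "b": jnp.zeros((16,))}
params = jax.pmap(lambda x: tree_map(lambda p: p + x, init))(jnp.zeros(n))

# The snippet from the issue: one slice per leaf, yielding stacked leaves of shape (k, ...).
sliced = tree_map(lambda x: x[:k], params)

# Workaround sketch: index each replica on its own (the fast path reported
# in this thread). The result is a list of k per-device trees, not one stacked tree.
replicas = [tree_map(lambda x, i=i: x[i], params) for i in range(k)]
```

The trade-off is the output structure. Code that expects a single tree with a leading axis of size 4 would need to be adapted to work with a list of per-device trees.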

The issue came up in the internal JAX chat, where we found that indexing along the device axis has no overhead but slicing does. Matt asked me to open an issue on GitHub.