Open danijar opened 2 years ago
Can you share a bit more? What platform/accelerator are you using? Is there a code snippet you can share to reproduce the issue?
Hi @mjsML, this was on TPU (DF 2x2) and the snippet I used was this, where the leading axis of the param tree is pmapped:
params = tree_map(lambda x: x[:4], params)
This was to get the parameter copies held by the first 4 of the 8 cores.
The issue came up in the internal JAX chat, where we found that indexing into the device axis has no overhead but slicing does. Matt asked me to open an issue on GitHub.
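For anyone who wants to try the snippet above in isolation, here is a minimal self-contained sketch. The parameter tree (its keys and shapes) is made up for illustration, and the identity `pmap` is just one assumed way to get leaves whose leading axis is the device axis. It keeps half of the local devices, which corresponds to 4 of 8 on the setup described above and to 1 on a single-device host.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

n_dev = jax.local_device_count()  # 8 on the setup described above; 1 on a plain CPU host
keep = max(1, n_dev // 2)         # 4 of 8 in the original snippet

# Hypothetical parameter tree; names and shapes are illustrative only.
params = {
    "dense": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))},
    "scale": jnp.ones(()),
}

# Give every leaf a leading device axis, then run an identity pmap so the
# leaves come back as pmap outputs (one copy per device).
replicated = tree_map(lambda x: jnp.broadcast_to(x, (n_dev,) + x.shape), params)
replicated = jax.pmap(lambda p: p)(replicated)

# Slice the device axis of every leaf, as in the snippet above.
first_half = tree_map(lambda x: x[:keep], replicated)

# Size of the device axis of each leaf after slicing.
leading_dims = tree_map(lambda x: x.shape[0], first_half)
print(leading_dims)
```

This only shows what the slicing does to the shapes; it says nothing about how much data is moved, which depends on the JAX version and backend.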
Let's say we get a ShardedDeviceArray sde back from pmap with axis=0 being the device axis. Currently, indexing (sde[3]) is fast but slicing (sde[:4]) triggers unnecessary data movement. This can be really slow, especially when slicing a parameter tree with many leaves, and should be easy to avoid.
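A rough way to compare the two operations is a small timing sketch like the one below. The array shape, the `time_op` helper, and the repeat count are assumptions made for illustration, and the numbers will vary with the JAX version and hardware. Recent JAX releases return a `jax.Array` from `pmap` rather than a ShardedDeviceArray, so the behaviour described here may not reproduce there.

```python
import time

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
keep = max(1, n_dev // 2)  # sde[:4] on an 8-core setup

# Output of pmap: axis 0 is the device axis.
sde = jax.pmap(lambda x: x * 2.0)(jnp.ones((n_dev, 256, 256)))
sde.block_until_ready()


def time_op(fn, repeats=10):
    """Average wall-clock seconds per call, waiting for each result."""
    fn().block_until_ready()  # warm-up call so one-time setup is not timed
    start = time.perf_counter()
    for _ in range(repeats):
        fn().block_until_ready()
    return (time.perf_counter() - start) / repeats


t_index = time_op(lambda: sde[n_dev - 1])  # indexing the device axis
t_slice = time_op(lambda: sde[:keep])      # slicing the device axis
print(f"index: {t_index:.2e} s   slice: {t_slice:.2e} s")
```

On a single-device host both operations touch the same device, so a meaningful comparison needs a multi-device backend such as the TPU setup mentioned in the thread.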