jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jnp.save in a distributed manner through GPUDirect #15281

Open sh0416 opened 1 year ago

sh0416 commented 1 year ago


When we load and save arrays with JAX, we typically use jnp.load, jnp.save, or pickle to serialize the objects.

However, with distributed arrays these operations gather the data through host memory, which can run out of memory and makes saving a checkpoint take far too long.

Roughly speaking, saving a 13B-parameter GPT state (params + optimizer state) takes about 13 minutes, which is equivalent to 10~100 update steps. If I save a checkpoint every 1000 steps, the save operation increases training time by 10% or more.

There is a DMA (Direct Memory Access)-like technology in CUDA, GPUDirect Storage, which enables direct transfers between GPU memory and disk: https://developer.nvidia.com/blog/gpudirect-storage/

It would be great to use this technique to accelerate research.

If this technology is already supported and just needs some jit option, please point me to it. I've sketched a snippet below, but I don't know how to fill in the options or whether it would work.

@jax.jit(in_sharding=???, static_argnum..?)
def save(tensor, path):
  jnp.save(path, tensor)  # file path first, then array, as with np.save

@jax.jit(out_sharding=???)
def load(path):
  return jnp.load(path)

Is there any other approach to handle this issue, or other threads that have struggled with it?

hawkinsp commented 1 year ago

You might look into https://github.com/google/tensorstore , which is what many of my colleagues use for checkpointing their JAX jobs. It is particularly valuable for distributed checkpointing since each worker can save pieces of the checkpoint in parallel.
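
For reference, a minimal sketch of what such a parallel write can look like with tensorstore (the path, dtype, shape, and shard bounds below are made-up placeholders; each process would write only the slice of the global array it owns):

import numpy as np
import tensorstore as ts

# Every worker opens the same zarr store backed by a shared filesystem.
store = ts.open({
    "driver": "zarr",
    "kvstore": {"driver": "file", "path": "/tmp/checkpoint/params"},
}, create=True, open=True, dtype=ts.float32, shape=[8192, 8192]).result()

# ...and writes only its own shard, e.g. a block of rows.
start, stop = 0, 1024  # hypothetical bounds for this worker's shard
local_shard = np.zeros((stop - start, 8192), dtype=np.float32)
store[start:stop, :].write(local_shard).result()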

sh0416 commented 1 year ago

That is a good candidate! I will consider it.

sh0416 commented 1 year ago

I found the gda serialization in the jax.experimental package. Saving LLaMA 65B in float16 (130GB) takes 4 minutes. Is that a reasonable time, or is there a way to make it faster? The Zstd compression algorithm (which is the fastest) is used, but the compression ratio is not great (the saved checkpoint is 110GB). Should I remove the compression algorithm inside tensorstore? I think a default algorithm (blosc?) is applied when I remove the compression setting.
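
For anyone landing here, roughly what that experimental path looks like (a sketch only: the module and function names follow the experimental API at the time and may have moved since, and the checkpoint directory and array are placeholders):

import jax.numpy as jnp
from jax.experimental.gda_serialization import serialization

ckpt_dir = "/tmp/checkpoint/params"  # placeholder path
params = jnp.ones((8192, 8192), dtype=jnp.float16)  # stand-in for real state

# Build a tensorstore spec for the checkpoint location, then write every
# locally addressable shard of the (possibly distributed) array to it.
spec = serialization.get_tensorstore_spec(ckpt_dir)
serialization.run_serialization([params], [spec])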

sh0416 commented 1 year ago

I found useful information at the following link: https://google.github.io/tensorstore/driver/zarr/index.html#json-driver/zarr.metadata.compressor
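
Concretely, per those docs the compressor is chosen in the zarr driver's metadata. A sketch (path, dtype, and shape are placeholders): "zstd" is selected explicitly here; setting the field to None disables compression, and omitting it falls back to the zarr default (blosc).

import tensorstore as ts

store = ts.open({
    "driver": "zarr",
    "kvstore": {"driver": "file", "path": "/tmp/checkpoint/params"},
    "metadata": {
        # Explicit codec choice; use None here for no compression,
        # or omit the field to get the zarr default (blosc).
        "compressor": {"id": "zstd", "level": 1},
    },
}, create=True, dtype=ts.float32, shape=[8192, 8192]).result()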