jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`numpy.memmap.flush()` in `jax` #20418

Closed hjnnjh closed 7 months ago

hjnnjh commented 7 months ago


I'm implementing a Stochastic Variational Inference (SVI) algorithm for my model. To avoid OOM, I tried using `jnp.load('*.npy', mmap_mode='r+')` to load a huge batch parameter array from disk. However, when I attempt to call `flush()` on the loaded array I get `'ArrayImpl' object has no attribute 'flush'`. Is `flush()` not implemented in JAX?
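
A minimal sketch of what this looks like; the file name and shape are placeholders standing in for the huge batch parameter array:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical parameter file, for illustration only.
np.save("params.npy", np.zeros((1024, 16), dtype=np.float32))

x = jnp.load("params.npy", mmap_mode="r+")  # comes back as a jax.Array, not a np.memmap
x.flush()  # AttributeError: 'ArrayImpl' object has no attribute 'flush'
```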

hjnnjh commented 7 months ago

I found that using `numpy.memmap` to load the parameters and then converting them to a `jax.Array` seems feasible, with acceptable efficiency (my parameters are saved on an M.2 solid-state drive). After some update steps, I convert the result back to an ndarray, write it back into the memmap, and call `flush()`.
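
A minimal sketch of this workaround, assuming a parameter file `params.npy` like the one in the snippet above and a stand-in update step:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Open the saved parameters as a writable memory map (np.memmap).
params_mm = np.load("params.npy", mmap_mode="r+")

@jax.jit
def update(p):
    # Stand-in for one SVI update step.
    return p * 0.99

# Bring a slice onto the device as a jax.Array and run some update steps...
batch = jnp.asarray(params_mm[:256])
for _ in range(10):
    batch = update(batch)

# ...then convert back to NumPy, write into the memmap, and flush to disk.
params_mm[:256] = np.asarray(batch)
params_mm.flush()
```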

hawkinsp commented 7 months ago

Yes. Certainly on CPU JAX can exchange buffers with NumPy zero-copy, so you can save an array by calling np.asarray on it and then using NumPy's facilities to do it. jnp.load and jnp.save are thin wrappers around the NumPy features.

I'm not completely sure it makes sense for us to implement flush since our arrays are immutable.
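
For reference, a small sketch of that pattern (the file name is a placeholder):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(8.0)     # a JAX array (on CPU its buffer can be shared with NumPy)
x_np = np.asarray(x)    # per the comment above, this exchange is zero-copy on CPU
np.save("x.npy", x_np)  # then use NumPy's own facilities to write it to disk
```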

hjnnjh commented 7 months ago

> Yes. Certainly on CPU JAX can exchange buffers with NumPy zero-copy, so you can save an array by calling np.asarray on it and then using NumPy's facilities to do it. jnp.load and jnp.save are thin wrappers around the NumPy features.
>
> I'm not completely sure it makes sense for us to implement flush since our arrays are immutable.

Thanks for your reply. What I mean is: is there some kind of `np.memmap` equivalent in JAX?

hjnnjh commented 7 months ago

> Yes. Certainly on CPU JAX can exchange buffers with NumPy zero-copy, so you can save an array by calling np.asarray on it and then using NumPy's facilities to do it. jnp.load and jnp.save are thin wrappers around the NumPy features. I'm not completely sure it makes sense for us to implement flush since our arrays are immutable.

> Thanks for your reply. What I mean is: is there some kind of `np.memmap` equivalent in JAX?

I just finished the SVI algorithm and ran some experiments. They showed there is no need for a `jnp.memmap`: using `np.asarray` to convert the parameters and then updating the corresponding `np.memmap` object is efficient enough. It saves a lot of GPU memory while keeping each step very fast. I think I can close this issue. Thank you! @hawkinsp