hjnnjh closed this issue 7 months ago.
I found that using `numpy.memmap` to load the parameters and then converting them to `jax.Array` is feasible, and the efficiency is acceptable (my parameters are saved on an M.2 solid-state drive). After some update steps, I convert them back to an `ndarray` and call `flush()`.
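A minimal sketch of the round trip described above, assuming a float32 parameter array stored as a `.npy` file. The file path, the shape and the `+ 0.1` "update step" are hypothetical placeholders, not the actual SVI update.

```python
import os
import tempfile

import jax.numpy as jnp
import numpy as np

# Hypothetical file location and shape, chosen only for illustration.
path = os.path.join(tempfile.mkdtemp(), "params.npy")
shape = (1000, 16)
np.save(path, np.zeros(shape, dtype=np.float32))

# Open the existing file as a writable memory map (data is paged in lazily).
params = np.load(path, mmap_mode="r+")

# Convert to a JAX array (on a GPU backend this moves the data to device memory).
p = jnp.asarray(params)

# A few hypothetical update steps standing in for the real SVI updates.
for _ in range(3):
    p = p + 0.1

# Write the result back into the memory map and persist it to disk.
params[:] = np.asarray(p)
params.flush()
```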
Yes. Certainly on CPU, JAX can exchange buffers with NumPy zero-copy, so you can save an array by calling `np.asarray` on it and then using NumPy's facilities to do it. `jnp.load` and `jnp.save` are thin wrappers around the NumPy features.
I'm not completely sure it makes sense for us to implement `flush`, since our arrays are immutable.
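A small illustration of the points above, using a throwaway file path (hypothetical): the JAX array is handed to NumPy with `np.asarray` for saving, `jnp.load` gives back a `jax.Array` rather than an `np.ndarray`, and an "update" through `.at[...].set` produces a new array instead of mutating the old one, which is why there is no in-place buffer to `flush`.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

x = jnp.linspace(0.0, 1.0, 5)

# np.asarray exposes the JAX array's data to NumPy (zero-copy is possible on CPU),
# and NumPy handles the serialization.
host = np.asarray(x)
path = os.path.join(tempfile.mkdtemp(), "x.npy")  # hypothetical path
np.save(path, host)

# jnp.load wraps np.load and returns a jax.Array.
y = jnp.load(path)

# JAX arrays are immutable: .at[...].set returns a new array and leaves y unchanged.
z = y.at[0].set(42.0)
```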
Thanks for your reply. I mean, is there some kind of `np.memmap` in `jax`?
I just finished the SVI algorithm and ran some experiments, in which I found there is no need to implement the so-called `jnp.memmap`: using `np.asarray` to convert the parameters and then updating the corresponding `np.memmap` object is efficient enough. It saves a lot of GPU memory while still completing each step very quickly. I think I can close this issue. Thank you! @hawkinsp
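The approach in the last comment can be sketched as follows, under the assumption (not stated in the thread) that the large array holds per-data-point local parameters and that each SVI step touches only the rows of one mini-batch. The sizes, the file path and `update` are hypothetical stand-ins for the real model; the point is that only the mini-batch slice is moved to the accelerator, and the result is written back into the `np.memmap`.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical sizes: many local parameter rows, small mini-batches.
num_docs, dim, batch_size = 10_000, 8, 64
path = os.path.join(tempfile.mkdtemp(), "local_params.npy")  # hypothetical path
np.save(path, np.ones((num_docs, dim), dtype=np.float32))

# Memory-mapped view of the whole array; rows are read from disk on access.
local = np.load(path, mmap_mode="r+")

@jax.jit
def update(batch_params, step_size):
    # Hypothetical update rule standing in for the real SVI local step.
    return batch_params * (1.0 - step_size)

rng = np.random.default_rng(0)
for step in range(5):
    idx = np.sort(rng.choice(num_docs, size=batch_size, replace=False))
    # Fancy indexing a memmap returns an in-memory copy of just these rows.
    batch = jnp.asarray(local[idx])
    new_batch = update(batch, 0.5)
    # Write the updated rows back into the memory-mapped file.
    local[idx] = np.asarray(new_batch)

# Persist all pending writes to disk.
local.flush()
```

Calling `flush()` once at the end lets the operating system decide when to write dirty pages; calling it after every step is also possible if durability after a crash matters more than speed.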
I'm implementing the Stochastic Variational Inference (SVI) algorithm for my model. To avoid OOM, I try to use `jnp.load('*.npy', mmap_mode='r+')` to load a huge batch parameter array from disk. However, when I attempt to call `flush()` on the loaded array, I get `'ArrayImpl' object has no attribute 'flush'`. Is `flush()` not implemented in `jax`?
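A minimal reproduction of the error, assuming a small throwaway file (hypothetical path and shape). `jnp.load` passes `mmap_mode` through to `np.load`, but then converts the resulting `np.memmap` into a `jax.Array`, so the returned object has no `flush`. Keeping the `np.memmap` returned by `np.load` and converting to and from JAX explicitly is the pattern the thread settles on.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

path = os.path.join(tempfile.mkdtemp(), "batch_params.npy")  # hypothetical path
np.save(path, np.zeros((4, 3), dtype=np.float32))

# jnp.load forwards mmap_mode to np.load, then converts the memmap to a jax.Array,
# so the file mapping (and its flush method) is lost.
a = jnp.load(path, mmap_mode="r+")

# np.load keeps the np.memmap, which does provide flush().
m = np.load(path, mmap_mode="r+")
m[0] = np.asarray(jnp.ones(3))
m.flush()
```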