Closed mofeing closed 1 year ago
If you change the unpack code to `tensor._data = array`, then `vmap` seems to work somewhat better. The problem is then that `psi` is supplied as a list to `square_norm`. Maybe one can play with the `in_axes`, but it might be that the function itself needs to support 'broadcasting' the correct axes, which here is a `list[pytree]`? Not totally sure!
In general I'm happy to add the pytree registration. The only downside is that one probably has to try importing `jax` and register eagerly, which is maybe not ideal startup-wise. Maybe it could be an opt-in function, `qtn.enable_jax_pytree_support()`.
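A rough sketch of what such an opt-in could look like, not quimb's actual API: `enable_pytree_support`, `ToyNetwork` and `ToyMPS` are hypothetical names, and the toy classes only stand in for a tensor network that holds a list of arrays plus some static metadata. The real `jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)` is imported lazily, so JAX is only loaded when the user opts in. The `register` argument is an assumption added here so the sketch can be exercised without JAX installed.

```python
class ToyNetwork:
    """Hypothetical stand-in for a tensor network: arrays plus static tags."""

    def __init__(self, arrays, tags=()):
        self.arrays = list(arrays)
        self.tags = tuple(tags)


class ToyMPS(ToyNetwork):
    pass


def _all_subclasses(cls):
    # Walk the class hierarchy, since each class has to be registered itself.
    for sub in cls.__subclasses__():
        yield sub
        yield from _all_subclasses(sub)


def enable_pytree_support(base, register=None):
    """Register `base` and every current subclass as pytree nodes (opt-in)."""
    if register is None:
        # Lazy import: JAX is only loaded when the user calls this function.
        from jax.tree_util import register_pytree_node as register

    registered = []
    for cls in (base, *_all_subclasses(base)):

        def flatten(tn):
            # (children, aux_data): arrays are traced, tags stay static.
            return tuple(tn.arrays), tn.tags

        def unflatten(tags, arrays, cls=cls):
            # Bind `cls` as a default so each class rebuilds its own type.
            return cls(arrays, tags)

        register(cls, flatten, unflatten)
        registered.append(cls)
    return registered


# Exercise the sketch with a recording stand-in instead of JAX.
calls = {}
enable_pytree_support(
    ToyNetwork, register=lambda cls, f, u: calls.update({cls: (f, u)})
)
flatten, unflatten = calls[ToyMPS]
children, aux = flatten(ToyMPS([1.0, 2.0], tags=("a",)))
rebuilt = unflatten(aux, children)
```

Subclasses defined after the call would not be covered by this sketch. Calling it twice with the real JAX function should also fail, because JAX rejects duplicate registrations of the same type.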
> If you change unpack code to `tensor._data = array` then `vmap` seems to work somewhat better, the problem is then that `psi` is supplied as a list to `square_norm`. Maybe one can play with the `in_axes`, but it might be that the function itself needs to support 'broadcasting' the correct axes, where is here `list[pytree]`? Not totally sure!
I've been researching a little and apparently JAX only allows mapping over JAX arrays, so we cannot `vmap` over a `list[MatrixProductState]`.
But `vmap` on a struct-of-arrays should be allowed. I'm doing some experiments with this and will comment when I have some results.
> In general happy to add the pytree register, the only downside is that one probably has to try importing jax and eagerly registering which is maybe not ideal startup wise, maybe it could be an opt in function, `qtn.enable_jax_pytree_support()`.
This could be solved using import hooks injected into `sys.meta_path`, such that the `register_pytree_node` calls run when `import jax` is executed.
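As a minimal sketch of that idea, using only the standard library: a finder placed at the front of `sys.meta_path` wraps the real loader of one chosen module and runs a callback right after the module has been executed. `PostImportFinder` and `_CallbackLoader` are hypothetical names, and `colorsys` is used purely as a stand-in for `jax` so the example runs without JAX. In the real use case the callback would presumably perform the `register_pytree_node` calls.

```python
import importlib.abc
import importlib.util
import sys


class _CallbackLoader(importlib.abc.Loader):
    """Delegate to the real loader, then run a callback on the new module."""

    def __init__(self, wrapped, callback):
        self.wrapped = wrapped
        self.callback = callback

    def create_module(self, spec):
        return self.wrapped.create_module(spec)

    def exec_module(self, module):
        self.wrapped.exec_module(module)
        self.callback(module)

    def __getattr__(self, name):
        # Forward anything else (get_source, get_data, ...) to the real loader.
        return getattr(self.wrapped, name)


class PostImportFinder(importlib.abc.MetaPathFinder):
    """Run `callback(module)` the first time `module_name` is imported."""

    def __init__(self, module_name, callback):
        self.module_name = module_name
        self.callback = callback
        self._busy = False  # prevents recursing into this finder

    def find_spec(self, fullname, path, target=None):
        if fullname != self.module_name or self._busy:
            return None
        self._busy = True
        try:
            # Let the remaining finders locate the real module.
            spec = importlib.util.find_spec(fullname)
        finally:
            self._busy = False
        if spec is None or spec.loader is None:
            return None
        spec.loader = _CallbackLoader(spec.loader, self.callback)
        return spec


# Demo with `colorsys` standing in for `jax`.
registered = []
sys.modules.pop("colorsys", None)  # make sure the import below is a fresh one
sys.meta_path.insert(
    0, PostImportFinder("colorsys", lambda mod: registered.append(mod.__name__))
)

import colorsys  # triggers the callback

import colorsys  # already cached in sys.modules, so the callback does not rerun
```

The hook costs nothing until the target module is actually imported, which is what avoids the eager `import jax` at startup. If the target was imported before the hook was installed, the callback never fires, so a real implementation would also need to check `sys.modules` first.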
I do have a `tensor_network_compile` implementation lying around that converts any TN function into an array function and compiles that using `autoray.autojit`, which you could then `vmap`. But again, that would be more using TNs to orchestrate array logic rather than as JAX nodes themselves.
> This could be solved using import hooks injected into `sys.meta_path` such that the `register_pytree_node`s are called when `import jax` is called.
I see, I'm not familiar with this function/submodule!
So I made `jax.vmap` work with `quimb`! The condition is that the inputs and outputs of the function must be JAX arrays, and some axis of the input args is the one over which you parallelize the map.
Also, I need to fix some lines in `Vectorizer.unpack` so that it does `array = autoray.do("reshape", array, shape)` instead of `array.shape = shape`, because the JAX tracer doesn't like the latter.
Here is an example:
    import jax
    import quimb.tensor as qtn

    L = 10
    batch_size = 32
    psi = qtn.MPS_rand_state(L, 4, normalize=True)
    phis = [qtn.MPS_rand_state(L, 4, normalize=True) for _ in range(batch_size)]
    # one stacked array per site, with the batch along the leading axis
    phis_arrays = [jax.numpy.asarray([phi.tensors[i].data for phi in phis]) for i in range(L)]

    def overlap(*arrays):
        # rebuild a single MPS from the per-site arrays of one batch element
        phi = qtn.MatrixProductState(arrays)
        return phi.H @ psi

    # batched overlaps, then batched gradients with respect to every site array
    jax.vmap(overlap)(*phis_arrays)
    jax.vmap(jax.grad(overlap, argnums=list(range(L))))(*phis_arrays)
In this sense, the solution would be similar to the `MakeArrayFn` class, which wraps a function so that it accepts arrays instead of a TN, but you need an extra "batching" dimension (which is what you specify in `in_axes`).
I have tried the same with vectorized TNs, but I'm running into problems. Specifically, the vectorizer is returning the same vector for different TNs. Here is the code I've tried. Any idea what I'm doing wrong?
    batch_size = 32
    psi = qtn.MPS_rand_state(10, 4, normalize=True)
    phis = [qtn.MPS_rand_state(10, 4, normalize=True) for _ in range(batch_size)]
    vectorizer = qtn.optimize.Vectorizer(phis[0].arrays)
    vect_phis = jax.numpy.asarray([vectorizer.pack(phi.arrays) for phi in phis])

    def vect_overlap(vect_arrays):
        arrays = vectorizer.unpack(vect_arrays)
        phi = qtn.MatrixProductState(arrays)
        return phi.H @ psi

    jax.vmap(vect_overlap)(vect_phis)
> I see, I'm not familiar with this function/submodule!

Yeah, it's one of those hacky, hidden Python modules. Here is an example of how to use import hooks: https://stackoverflow.com/a/54456931
> So I made `jax.vmap` work with quimb! The condition is that the inputs and outputs of the function must be JAX arrays, and some axis of the input args is the one over which you parallelize the map.
Nice, yes exposing a 'raw array' function interface might be generally useful - that's basically what the tn compiler decorator does too.
Regarding `Vectorizer`, it is currently specifically for vectorizing to a single, real, double-precision numpy array to be used with the scipy/nlopt etc. optimizers, i.e. it won't work and is not intended for other purposes such as being traced through, despite its rather general name...
A class that simply goes from arrays to flattened vector form should actually be simpler than this: you'd ignore all the dtype stuff and simply use `concatenate` on the flattened arrays, rather than reusing memory as `Vectorizer` does, so that it's traceable etc.
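A minimal sketch of such a simpler class, assuming nothing about quimb's internals: `FlatPacker` is a hypothetical name, shapes are recorded once, and both directions are purely functional (flatten plus `concatenate` one way, slice plus `reshape` the other), with no shared buffer and no dtype handling. It is written against NumPy so it runs as is. The `xp` argument is an assumption meant to let a NumPy-like namespace such as `jax.numpy` be swapped in, which has not been tested here.

```python
import itertools
import math

import numpy as np


class FlatPacker:
    """Map a fixed list of array shapes to and from a single flat vector."""

    def __init__(self, arrays):
        self.shapes = [tuple(a.shape) for a in arrays]
        sizes = [math.prod(shape) for shape in self.shapes]
        # Start and stop offsets of each array inside the flat vector.
        self.offsets = list(itertools.accumulate(sizes, initial=0))

    def pack(self, arrays, xp=np):
        # Builds a fresh vector on every call instead of reusing memory.
        return xp.concatenate([xp.reshape(a, (-1,)) for a in arrays])

    def unpack(self, vector, xp=np):
        return [
            xp.reshape(vector[start:stop], shape)
            for start, stop, shape in zip(
                self.offsets[:-1], self.offsets[1:], self.shapes
            )
        ]


# Round trip on arrays shaped like a small 3-site MPS.
rng = np.random.default_rng(0)
arrays = [rng.normal(size=shape) for shape in [(4, 2), (4, 4, 2), (4, 2)]]
packer = FlatPacker(arrays)
vec = packer.pack(arrays)
restored = packer.unpack(vec)
```

Because `pack` returns a new vector each time, packing several states in a loop yields distinct vectors, unlike a design that writes every result into one shared buffer.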
Is your feature request related to a problem?
I've been running into problems when trying to make JAX and Quimb cooperate. JAX only supports its own native formats (e.g. `jax.numpy.array`). Quimb works around this obstacle by manually:
jax.numpy.array
sjax
These steps are performed by the `JaxHandler` class, which is used by the `TNOptimizer` class. It seems to work well for the purposes of `TNOptimizer`, but I'm running into problems when I want to do more.
Mainly, if you directly pass a TN to a function transformed by JAX (i.e. `jax.grad`, `jax.jit`, `jax.vmap`), it crashes completely because JAX does not recognize any TN class as compatible.
The solution is to register `TensorNetwork` and its subclasses as JAX-compatible. Fortunately, we can do this with the `jax.tree_util.register_pytree_node` function.
Describe the solution you'd like
This code is working for me when calling `jax.jit` and `jax.grad`, but not when calling `jax.vmap`.
I'm thinking about how to generalize this for `TensorNetwork` and its subclasses, because `register_pytree_node` does not ascend through the class hierarchy. One solution is to call `register_pytree_node` inside `TensorNetwork.__init_subclass__`, such that it is called every time we inherit from it.
Describe alternatives you've considered
No response
Additional context
These are the examples I have tested against: