jcmgray / quimb

A python library for quantum information and many-body calculations including tensor networks.
http://quimb.readthedocs.io

Register TensorNetwork and subclasses as JAX-compatible types #148

Closed: mofeing closed this issue 1 year ago

mofeing commented 1 year ago

Is your feature request related to a problem?

I've been running into problems when trying to make JAX and quimb cooperate. JAX only supports its own native types (e.g. jax.numpy arrays). quimb works around this obstacle by manually:

  1. extracting the arrays of the TNs
  2. converting them to jax.numpy arrays
  3. creating a partial of the function that accepts the arrays and reconstructs the TN
  4. changing the contraction backend to jax

These steps are performed by the JaxHandler class, which is used by the TNOptimizer class. It seems to work well for the purposes of TNOptimizer, but I run into problems as soon as I want to do anything beyond that.
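The four steps above can be sketched without quimb. In this minimal sketch, `ToyNetwork` and `rebuild_and_call` are hypothetical stand-ins for a TN class and for what JaxHandler does; step 4 (switching the contraction backend) is implicit, because the toy contracts with `jax.numpy` directly.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


class ToyNetwork:
    """Hypothetical stand-in for a TN: a chain of matrices."""

    def __init__(self, arrays):
        self.arrays = list(arrays)

    def norm_squared(self):
        # Contract the chain into one matrix, then take its squared Frobenius norm.
        out = self.arrays[0]
        for a in self.arrays[1:]:
            out = out @ a
        return jnp.sum(out ** 2)


def loss(tn):
    return tn.norm_squared()


def rebuild_and_call(fn, arrays):
    # Rebuild a network around the (possibly traced) arrays, then call fn on it.
    return fn(ToyNetwork(arrays))


rng = np.random.default_rng(0)
tn = ToyNetwork([rng.random((3, 3)) for _ in range(4)])

# Steps 1 and 2: extract the arrays and convert them to JAX arrays.
arrays = [jnp.asarray(a) for a in tn.arrays]

# Step 3: a partial that only takes arrays, which JAX knows how to trace.
array_fn = functools.partial(rebuild_and_call, loss)

value = jax.jit(array_fn)(arrays)
grads = jax.grad(array_fn)(arrays)  # a list of arrays with the same shapes
```

The TN object never crosses the JAX boundary here; only the list of arrays does, which is why this pattern works for TNOptimizer but not when a TN is passed straight into a transformed function.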

Mainly, if you pass a TN directly to a function transformed by JAX (e.g. jax.grad, jax.jit, jax.vmap), it crashes because JAX does not recognize any TN class as a valid input type.

The solution is to register TensorNetwork and its subclasses as JAX-compatible types. Fortunately, we can do this with the jax.tree_util.register_pytree_node function.

Describe the solution you'd like

This code works for me with jax.jit and jax.grad, but not with jax.vmap.

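import jax
import quimb.tensor as qtn

def pack(obj):
    # flatten: the tensor arrays are the leaves JAX traces; the TN itself is
    # kept as auxiliary data and used as a template when unflattening
    return (obj.arrays, obj)

def unpack(aux, children):
    # unflatten: copy the template TN and put the (possibly traced) arrays back
    obj = aux.copy()
    for tensor, array in zip(obj.tensors, children):
        tensor.modify(data=array)

    return obj

jax.tree_util.register_pytree_node(qtn.MatrixProductState, pack, unpack)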

I'm thinking about how to generalize this to TensorNetwork and its subclasses, because register_pytree_node does not apply to subclasses of a registered class. One solution is to call register_pytree_node inside TensorNetwork.__init_subclass__, so that it runs every time a class inherits from it.
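A minimal sketch of the `__init_subclass__` idea, using a hypothetical `ToyTensorNetwork` class in place of quimb's TensorNetwork (a real quimb version would copy the TN in its flatten/unflatten, as pack/unpack do above). The hook only fires for subclasses, so the base class itself is registered explicitly.

```python
import jax
import jax.numpy as jnp


class ToyTensorNetwork:
    """Hypothetical base class; every subclass becomes a JAX pytree."""

    def __init__(self, arrays):
        self.arrays = tuple(arrays)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs once per subclass definition, so subclasses are registered too.
        _register(cls)

    def tree_flatten(self):
        # children: the arrays JAX should trace; no extra static data needed here.
        return self.arrays, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children)


def _register(cls):
    jax.tree_util.register_pytree_node(
        cls,
        lambda tn: tn.tree_flatten(),
        cls.tree_unflatten,
    )


# __init_subclass__ does not run for the base class, so register it explicitly.
_register(ToyTensorNetwork)


class ToyMPS(ToyTensorNetwork):
    def norm_squared(self):
        return sum(jnp.sum(a ** 2) for a in self.arrays)


mps = ToyMPS([jnp.ones((2, 2)), jnp.ones((2,))])
grad = jax.jit(jax.grad(lambda psi: psi.norm_squared()))(mps)
```

Because the subclass is registered, the gradient comes back as a `ToyMPS` whose arrays hold the derivatives. The decorator `jax.tree_util.register_pytree_node_class` also registers only the class it decorates, so it would need the same hook.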

Describe alternatives you've considered

No response

Additional context

These are the examples I have tested against:

import jax
import quimb.tensor as qtn

mps = qtn.MPS_rand_state(10, 4, normalize=True)

def square_norm(psi):
    return psi.H @ psi

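jax.jit(jax.grad(square_norm))(mps)  # works with the registration above

import jax
import quimb.tensor as qtn

# second example, reusing square_norm from the first one
mps_list = [qtn.MPS_rand_state(10, 4, normalize=True) for _ in range(16)]

jax.vmap(square_norm)(mps_list)  # does not work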
jcmgray commented 1 year ago

If you change the unpack code to tensor._data = array, then vmap seems to work somewhat better; the problem is then that psi is supplied to square_norm as a list. Maybe one can play with in_axes, but it might be that the function itself needs to support 'broadcasting' over the correct axes, which here is list[pytree]? Not totally sure!

In general I'm happy to add the pytree registration. The only downside is that one probably has to try importing jax and registering eagerly, which is maybe not ideal startup-wise; maybe it could be an opt-in function, qtn.enable_jax_pytree_support().

mofeing commented 1 year ago

> If you change unpack code to tensor._data = array then vmap seems to work somewhat better, the problem is then that psi is supplied as a list to square_norm. Maybe one can play with the in_axes, but it might be that the function itself needs to support 'broadcasting' the correct axes, where is here list[pytree]? Not totally sure!

I've been researching a little, and apparently JAX only allows mapping over JAX arrays, so we cannot vmap over a list[MatrixProductState].

But vmap on a struct-of-arrays should be allowed. I'm doing some experiments with this and will comment when I have some results.
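To illustrate the list-of-structs vs struct-of-arrays distinction: with the default in_axes=0, vmap maps over the leading axis of each array leaf, so a Python list of N states has to become a single state whose leaves each carry a leading batch axis of size N. This sketch uses plain lists of arrays as stand-ins for TNs, stacks them with jax.tree_util.tree_map, and assumes every state has identical shapes.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)

# A batch of 16 "states", each a list of arrays with identical shapes
# (a list-of-structs, which vmap cannot map over directly).
states = [
    [jnp.asarray(rng.normal(size=(2, 4))), jnp.asarray(rng.normal(size=(4, 2)))]
    for _ in range(16)
]

# Convert to a struct-of-arrays: one list whose leaves carry a leading batch axis.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *states)


def norm_squared(arrays):
    # Per-example function: sees one state's arrays, without the batch axis.
    a, b = arrays
    m = a @ b
    return jnp.sum(m ** 2)


batched = jax.vmap(norm_squared)(stacked)
```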

> In general happy to add the pytree register, the only downside is that one probably has to try importing jax and eagerly registering which is maybe not ideal startup wise, maybe it could be an opt in function, qtn.enable_jax_pytree_support().

This could be solved with import hooks injected into sys.meta_path, so that the register_pytree_node calls run when jax is imported.
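A rough sketch of the import-hook idea, in the spirit of the Stack Overflow answer linked further down: a finder placed at the front of sys.meta_path intercepts the first import of a module and runs a callback once the module body has executed. `PostImportFinder` and `when_imported` are hypothetical names; the sketch assumes the target is an ordinary file-based module (as jax is) and has not been hardened for reloads or threads.

```python
import importlib.abc
import importlib.util
import sys


class PostImportFinder(importlib.abc.MetaPathFinder):
    """Hypothetical meta-path finder that runs a callback after a module loads."""

    def __init__(self, module_name, callback):
        self.module_name = module_name
        self.callback = callback
        self._searching = False

    def find_spec(self, fullname, path=None, target=None):
        # Ignore other modules, and ignore our own nested lookup below.
        if fullname != self.module_name or self._searching:
            return None
        self._searching = True
        try:
            # Let the remaining finders on sys.meta_path locate the real module.
            spec = importlib.util.find_spec(fullname)
        finally:
            self._searching = False
        if spec is None or spec.loader is None:
            return None
        # Assumes a file-based module, whose loader is a fresh per-spec
        # instance, so patching it does not affect other imports.
        original_exec = spec.loader.exec_module
        callback = self.callback

        def exec_module(module):
            original_exec(module)  # run the module body as usual
            callback(module)       # then run the post-import hook

        spec.loader.exec_module = exec_module
        return spec


def when_imported(module_name, callback):
    """Call callback(module) now if already imported, else right after import."""
    if module_name in sys.modules:
        callback(sys.modules[module_name])
    else:
        sys.meta_path.insert(0, PostImportFinder(module_name, callback))
```

In quimb this might be used as `when_imported("jax", lambda jax: register_tn_pytrees())`, where `register_tn_pytrees` is a hypothetical function doing the registrations; users who never import jax would pay nothing.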

jcmgray commented 1 year ago

I do have a tensor_network_compile implementation lying around that converts any TN function into an array function and compiles it using autoray.autojit; you could then vmap that. But again, that would be using TNs to orchestrate array logic rather than as JAX nodes themselves.

> This could be solved using import hooks injected into sys.meta_path such that the register_pytree_nodes are called when import jax is called.

I see, I'm not familiar with this function/submodule!

mofeing commented 1 year ago

So I made jax.vmap work with quimb! The condition is that the inputs and outputs of the function must be JAX arrays, and some axis of the input args is the one over which you parallelize the map.

Also, I needed to fix some lines in Vectorizer.unpack so that it does array = autoray.do("reshape", array, shape) instead of array.shape = shape, because JAX tracers do not support assigning to .shape.
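The difference between the two reshape styles, shown with jax.numpy.reshape instead of autoray.do("reshape", ...) to keep the sketch dependency-free: assigning to .shape mutates an existing NumPy array, but JAX arrays and tracers are immutable, so only the functional form works inside jax.jit.

```python
import jax
import jax.numpy as jnp
import numpy as np


def unpack_inplace(flat, shape):
    # NumPy idiom: changes the shape metadata of an existing array in place.
    flat.shape = shape
    return flat


def unpack_functional(flat, shape):
    # Returns a new array, so it also works on immutable JAX arrays and tracers.
    return jnp.reshape(flat, shape)


print(unpack_inplace(np.arange(6.0), (2, 3)).shape)  # NumPy accepts this

try:
    jax.jit(lambda v: unpack_inplace(v, (2, 3)))(jnp.arange(6.0))
    inplace_ok = True
except (AttributeError, TypeError):
    inplace_ok = False  # the tracer's shape attribute cannot be assigned to

functional = jax.jit(lambda v: unpack_functional(v, (2, 3)))(jnp.arange(6.0))
print(inplace_ok, functional.shape)
```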

Here is an example:

import jax
import quimb.tensor as qtn

L = 10
batch_size = 32
psi = qtn.MPS_rand_state(L, 4, normalize=True)
phis = [qtn.MPS_rand_state(L, 4, normalize=True) for _ in range(batch_size)]

# one array per site, with the batch of phis stacked along a new leading axis
phis_arrays = [jax.numpy.asarray([state.tensors[i].data for state in phis]) for i in range(L)]

def overlap(*arrays):
    # rebuild a single (unbatched) MPS from its site arrays
    phi = qtn.MatrixProductState(arrays)
    return phi.H @ psi

jax.vmap(overlap)(*phis_arrays)
jax.vmap(jax.grad(overlap, argnums=tuple(range(L))))(*phis_arrays)

In this sense, the solution is similar to the MakeArrayFn class, which wraps a function so that it accepts arrays instead of a TN, except that you need an extra "batching" dimension (which is what you specify with in_axes).
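The extra batching dimension can be made explicit with in_axes: map over axis 0 of the batched arguments and pass None for arguments that should be broadcast unchanged. In this sketch the MPS contraction is replaced by a toy overlap of product states (one vector per site), an assumption made only to keep it self-contained; unlike the example above, psi is passed as an argument instead of being closed over.

```python
import jax
import jax.numpy as jnp
import numpy as np

L, d, batch_size = 5, 2, 8
rng = np.random.default_rng(1)

# Toy product states: one length-d vector per site.
psi = [jnp.asarray(rng.normal(size=d)) for _ in range(L)]
# Batched inputs: site i of all phis stacked along a leading batch axis.
phis = [jnp.asarray(rng.normal(size=(batch_size, d))) for _ in range(L)]


def overlap(phi_arrays, psi_arrays):
    # For product states <phi|psi> is the product of per-site inner products.
    out = 1.0
    for a, b in zip(phi_arrays, psi_arrays):
        out = out * jnp.dot(a, b)
    return out


# in_axes=(0, None): map over axis 0 of every phi array, reuse psi unbatched.
overlaps = jax.vmap(overlap, in_axes=(0, None))(phis, psi)
grads = jax.vmap(jax.grad(overlap), in_axes=(0, None))(phis, psi)
```

Here `overlaps` has shape `(batch_size,)` and `grads` is a list of L arrays, each of shape `(batch_size, d)`.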

I have tried the same with vectorized TNs, but I'm running into problems. Specifically, the vectorizer returns the same vector for different TNs. Here is the code I've tried. Any idea what I'm doing wrong?

import jax
import quimb.tensor as qtn

batch_size = 32
psi = qtn.MPS_rand_state(10, 4, normalize=True)
phis = [qtn.MPS_rand_state(10, 4, normalize=True) for _ in range(batch_size)]

vectorizer = qtn.optimize.Vectorizer(phis[0].arrays)
vect_phis = jax.numpy.asarray([vectorizer.pack(phi.arrays) for phi in phis])

def vect_overlap(vect_arrays):
    arrays = vectorizer.unpack(vect_arrays)
    phi = qtn.MatrixProductState(arrays)
    return phi.H @ psi

jax.vmap(vect_overlap)(vect_phis)
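One likely cause, consistent with jcmgray's later remark that Vectorizer reuses memory: if pack writes into one preallocated buffer and returns that buffer, the list comprehension collects batch_size references to the same array, all holding the last packed state. The `BufferPacker` toy below is hypothetical and only mimics that behaviour; it is not quimb's Vectorizer.

```python
import numpy as np


class BufferPacker:
    """Hypothetical packer that reuses one preallocated output buffer."""

    def __init__(self, size):
        self.vector = np.empty(size)

    def pack(self, arrays):
        # Writes into the same buffer every call and returns it (no copy).
        i = 0
        for a in arrays:
            n = a.size
            self.vector[i:i + n] = a.ravel()
            i += n
        return self.vector


packer = BufferPacker(4)
batch = [[np.full(2, k), np.full(2, k)] for k in range(3)]

aliased = [packer.pack(arrays) for arrays in batch]
copied = [packer.pack(arrays).copy() for arrays in batch]

print(np.stack(aliased))  # every row equals the last packed state
print(np.stack(copied))   # one distinct row per state, as intended
```

Copying each packed vector, as in `copied`, avoids the aliasing, although per the reply further down, Vectorizer is still not meant to be traced through.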

> I see, I'm not familiar with this function/submodule!

Yeah, it's one of those hacky, hidden Python modules. Here is an example of how to use import hooks: https://stackoverflow.com/a/54456931

jcmgray commented 1 year ago

> So I made jax.vmap work with quimb! The condition is that the inputs and outputs of the function must be JAX arrays, and some axis of the input args is the one over which you parallelize the map.

Nice, yes, exposing a 'raw array' function interface might be generally useful; that's basically what the TN compiler decorator does too.

Regarding Vectorizer, it is currently specifically for vectorizing to a single, real, double-precision numpy array to be used with the scipy/nlopt etc. optimizers, i.e. it won't work and is not intended for other purposes such as being traced through, despite its rather general name...

A class that simply goes from arrays to flattened vector form should actually be simpler than this: you'd ignore all the dtype handling and simply concatenate the flattened arrays, rather than reusing memory as Vectorizer does, so that it's traceable etc.
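Such a class might look like this minimal sketch; `ArrayVectorizer` is a hypothetical name, and it only handles arrays whose shapes are fixed at construction. For arbitrary pytrees, JAX also provides `jax.flatten_util.ravel_pytree`, which returns a flat vector together with a function that unravels it.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ArrayVectorizer:
    """Hypothetical traceable vectorizer: concatenate on pack, slice+reshape on unpack."""

    def __init__(self, arrays):
        self.shapes = [tuple(a.shape) for a in arrays]
        self.sizes = [int(np.prod(s)) for s in self.shapes]

    def pack(self, arrays):
        # Builds a new vector on every call, so there is no shared buffer.
        return jnp.concatenate([jnp.ravel(a) for a in arrays])

    def unpack(self, vector):
        # Pure slicing and reshaping, so this works on tracers under jit/vmap.
        arrays, i = [], 0
        for shape, size in zip(self.shapes, self.sizes):
            arrays.append(jnp.reshape(vector[i:i + size], shape))
            i += size
        return arrays


rng = np.random.default_rng(2)
shapes = [(2, 3), (3, 3, 2), (3, 2)]
states = [[jnp.asarray(rng.normal(size=s)) for s in shapes] for _ in range(4)]

vec = ArrayVectorizer(states[0])
vects = jnp.stack([vec.pack(arrays) for arrays in states])  # shape (4, 30)


def total(vector):
    # Per-example function: unpack inside the traced code, then compute.
    return sum(jnp.sum(a ** 2) for a in vec.unpack(vector))


batched = jax.vmap(total)(vects)
```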