jcmgray / quimb

A python library for quantum information and many-body calculations including tensor networks.
http://quimb.readthedocs.io

TNOptimizer, unable to use jax autodiff with dask executor #84

Open chris-n-self opened 3 years ago

chris-n-self commented 3 years ago

I don't seem to be able to use a jax autodiff backend with a TNOptimizer object while also using a dask executor to parallelise over a multi-component loss function.

Versions:

- quimb -- 1.3.0+276.gee67688
- cotengra -- up to date with GitHub version
- jax -- 0.2.10
- jaxlib -- 0.1.62
- dask -- 2021.3.0
- cloudpickle -- 1.6.0

Example that gives me the problem:

import functools

import quimb as qu
import quimb.tensor as qtn
from autoray import do

from cotengra.parallel import get_pool

n = 4
depth = 2
circ = qtn.circ_ansatz_1D_brickwork(n, depth)

J = 1
hX = 0.5
nn_edges = [(i, (i + 1) % n) for i in range(n)]
ops = {nn: J * qu.pauli('Z') & qu.pauli('Z') for nn in nn_edges}
ops.update({i: -hX * qu.pauli('X') for i in range(n)})

lightcone_tags = {where: circ.get_reverse_lightcone_tags(where) for where in ops}

# function that takes the input TN and computes the ith loss term
def loss_i(psi, where, ops):
    # restrict to the reverse lightcone of this term
    tags = lightcone_tags[where]
    ket = psi.select(tags, 'any')
    bra = ket.H
    expec = ket.gate(ops[where], where) | bra
    return do('real', expec.contract(all, optimize='auto-hq'))

# since the loss is a sum, the terms can be evaluated separately
loss_fns = [
    functools.partial(loss_i, where=where)
    for where in ops
]

dask_client = get_pool(n_workers=4, maybe_create=True, backend='dask')

tnopt = qtn.TNOptimizer(
    circ.psi,
    loss_fn=loss_fns,
    tags=['U3'],
    loss_constants={'ops': ops},
    autodiff_backend='jax',
    optimizer='L-BFGS-B',
    executor=dask_client,
)

circ_opt = tnopt.optimize_basinhopping(n=1, nhop=1)
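The pattern of splitting a summed loss into per-term callables with functools.partial can be sanity-checked in isolation. This is a minimal sketch with a toy stand-in loss, not the quimb API:

```python
import functools

# toy stand-in for a loss that is a sum of independent terms:
# each callable is bound to its own 'where' index via functools.partial
def term_loss(x, where):
    return (x - where) ** 2

term_fns = [functools.partial(term_loss, where=w) for w in range(3)]

# evaluating the terms separately and summing recovers the total loss
total = sum(f(1.0) for f in term_fns)
print(total)  # (1-0)^2 + (1-1)^2 + (1-2)^2 = 2.0
```

This is the same shape as the loss_fns list above: a parallel executor can then evaluate each bound term independently.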

The end of the traceback is:

~/anaconda3/envs/qaoadf-cotengra/lib/python3.8/site-packages/cloudpickle/cloudpickle_fast.py in dump(self, obj)
    561     def dump(self, obj):
    562         try:
--> 563             return Pickler.dump(self, obj)
    564         except RuntimeError as e:
    565             if "recursion" in e.args[0]:

TypeError: cannot pickle 'jaxlib.xla_extension.jax_jit.CompiledFunction' object
jcmgray commented 3 years ago

Yes, this seems to be a known regression in jax -- https://github.com/google/jax/issues/5043 -- where jitted functions can't be pickled, even by cloudpickle. I don't have any immediate solution other than to try a different autodiff_backend.
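The failure mode can be reproduced in miniature without jax: dask ships tasks to workers by pickling them, so any callable wrapping unserializable state breaks at that step. Below is a rough stdlib analogy (the stdlib pickler rejects local closures, loosely analogous to how cloudpickle rejects jax's compiled C-extension objects; this is illustrative, not the actual jax mechanism):

```python
import pickle

def jit_like(fn):
    # the returned wrapper is a local closure, which the stdlib
    # pickler cannot serialize -- loosely analogous to jax's
    # CompiledFunction objects, which even cloudpickle rejects
    cache = {}
    def wrapper(x):
        if x not in cache:
            cache[x] = fn(x)
        return cache[x]
    return wrapper

double = jit_like(lambda x: x * 2)

try:
    pickle.dumps(double)
    picklable = True
except (pickle.PicklingError, AttributeError, TypeError):
    picklable = False

print(picklable)  # False -- an executor that pickles tasks cannot ship it
```

In the issue above the same thing happens one layer down: the jit-compiled loss is a C-extension object that no pickler can handle, so it cannot cross the dask process boundary.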