jcmgray / quimb

A python library for quantum information and many-body calculations including tensor networks.
http://quimb.readthedocs.io

Trying to optimize circuit using local_expectation method in loss function #77

Closed JustinS6626 closed 3 years ago

JustinS6626 commented 3 years ago

I have built the circuit structure that you showed in

https://github.com/jcmgray/quimb/issues/76

and what I am trying to do now is to optimize the circuit using the TNOptimizer class with basin-hopping.

The code for my approach is as follows:

import quimb as qu
import quimb.tensor as qtn
import numpy as np
import jax.numpy as jnp

L = 6
N = 2**L
circ = qtn.Circuit(2**L)

for l in range(L):
    scale = 2**l
    for i in range(0, 1 << (L - l), 2):
        circ.su4(*qu.randn(15), scale * i, scale * (i + 1) % N)

    for i in range(1, 1 << (L - l), 2):
        circ.su4(*qu.randn(15), scale * i, scale * (i + 1) % N)

#print([item.tags for item in circ.gates])
target = np.random.random(N).tolist()

def loss(U, psi, target, n_qubits):
    sites = list(range(n_qubits))
    exp_vals = []
    psi.gate_(U, sites)
    qc = qtn.Circuit(n_qubits, psi0=psi)
    loss = 0
    i = 0
    obs = qu.pauli('Z')
    while i < n_qubits:
        expect = qc.local_expectation(obs, i, backend="jax").item().real
        exp_vals.append(expect)
        i += 1
    for e, t in zip(exp_vals, target):
        loss += ((e - t) ** 2)
    loss /= n_qubits
    return loss

opt = qtn.optimize.TNOptimizer(circ.uni,
                               loss,
                               loss_kwargs={"psi": circ.psi,
                                            "target" : target,
                                            "n_qubits" : N},
                               tags=["SU4"],
                               autodiff_backend="jax",
                               optimizer="L-BFGS-B")
U_opt = opt.optimize_basinhopping(n=50, nhop=10)

When I try to run the code, I get the following error traceback chain:

0%| | 0/500 [00:00<?, ?it/s]WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

0%| | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/justin/darthmallocsarchive-svn/trunk/DissertationExperiment/circuit_MERA_test.py", line 48, in <module>
    U_opt = opt.optimize_basinhopping(n=50, nhop=10)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 946, in optimize_basinhopping
    self.res = basinhopping(
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 678, in basinhopping
    bh = BasinHoppingRunner(x0, wrapped_minimizer, take_step_wrapped,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 72, in __init__
    minres = minimizer(self.x)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 284, in __call__
    return self.minimizer(self.func, x0, **self.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_minimize.py", line 617, in minimize
    return _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/lbfgsb.py", line 306, in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 261, in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 76, in __init__
    self._update_fun()
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 166, in _update_fun
    self._update_fun_impl()
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 73, in update_fun
    self.f = fun_wrapped(self.x)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 70, in fun_wrapped
    return fun(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 74, in __call__
    self._compute_if_needed(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 68, in _compute_if_needed
    fg = self.fun(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 846, in vectorized_value_and_grad
    ag_result, ag_grads = self.handler.value_and_grad(arrays)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 228, in value_and_grad
    loss, grads = self._value_and_grad(arrays)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 695, in __call__
    return self.loss_fn(self.norm_fn(tn_compute))
  File "/home/justin/darthmallocsarchive-svn/trunk/DissertationExperiment/circuit_MERA_test.py", line 32, in loss
    expect = qc.local_expectation(obs, i, backend="jax").item().real
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1490, in local_expectation
    rho = self.get_rdm_lightcone_simplified(
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1171, in get_rdm_lightcone_simplified
    ket_lc = self.get_psi_reverse_lightcone(where)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1052, in get_psi_reverse_lightcone
    psi = self.psi
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 946, in psi
    psi.astype(psi.dtype)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/tensor_core.py", line 5428, in dtype
    return common_type(*self)
  File "/usr/local/lib/python3.8/dist-packages/quimb/core.py", line 293, in common_type
    dtypes = {array.dtype.name for array in arrays}
  File "/usr/local/lib/python3.8/dist-packages/quimb/core.py", line 293, in <setcomp>
    dtypes = {array.dtype.name for array in arrays}
jax._src.traceback_util.FilteredStackTrace: AttributeError: 'str' object has no attribute 'name'

The stack trace above excludes JAX-internal frames. The following is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/justin/darthmallocsarchive-svn/trunk/DissertationExperiment/circuit_MERA_test.py", line 48, in <module>
    U_opt = opt.optimize_basinhopping(n=50, nhop=10)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 946, in optimize_basinhopping
    self.res = basinhopping(
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 678, in basinhopping
    bh = BasinHoppingRunner(x0, wrapped_minimizer, take_step_wrapped,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 72, in __init__
    minres = minimizer(self.x)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 284, in __call__
    return self.minimizer(self.func, x0, **self.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_minimize.py", line 617, in minimize
    return _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/lbfgsb.py", line 306, in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 261, in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 76, in __init__
    self._update_fun()
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 166, in _update_fun
    self._update_fun_impl()
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 73, in update_fun
    self.f = fun_wrapped(self.x)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 70, in fun_wrapped
    return fun(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 74, in __call__
    self._compute_if_needed(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 68, in _compute_if_needed
    fg = self.fun(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 846, in vectorized_value_and_grad
    ag_result, ag_grads = self.handler.value_and_grad(arrays)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 228, in value_and_grad
    loss, grads = self._value_and_grad(arrays)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/api.py", line 371, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/api.py", line 278, in cache_miss
    out_flat = xla.xla_call(
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 598, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/xla.py", line 569, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/xla.py", line 645, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/api.py", line 769, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/usr/local/lib/python3.8/dist-packages/jax/api.py", line 1846, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py", line 114, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 516, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 695, in __call__
    return self.loss_fn(self.norm_fn(tn_compute))
  File "/home/justin/darthmallocsarchive-svn/trunk/DissertationExperiment/circuit_MERA_test.py", line 32, in loss
    expect = qc.local_expectation(obs, i, backend="jax").item().real
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1490, in local_expectation
    rho = self.get_rdm_lightcone_simplified(
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1171, in get_rdm_lightcone_simplified
    ket_lc = self.get_psi_reverse_lightcone(where)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1052, in get_psi_reverse_lightcone
    psi = self.psi
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 946, in psi
    psi.astype(psi.dtype)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/tensor_core.py", line 5428, in dtype
    return common_type(*self)
  File "/usr/local/lib/python3.8/dist-packages/quimb/core.py", line 293, in common_type
    dtypes = {array.dtype.name for array in arrays}
  File "/usr/local/lib/python3.8/dist-packages/quimb/core.py", line 293, in <setcomp>
    dtypes = {array.dtype.name for array in arrays}
AttributeError: 'str' object has no attribute 'name'

The problem seems to be in the call to the local_expectation method, and I am wondering if there is a correct way to carry out what I am trying to achieve in this example. The example that I am showing is a toy example for a more complicated task, but the main objective that I am trying to achieve right now for that is to optimize a circuit using the jax backend with a loss function that calls the local_expectation method. Any help with that would be greatly appreciated.
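
For readers who want to see what the loss above is meant to compute, here is a minimal sketch on a dense state vector, using only NumPy rather than quimb. The helper names `z_expectation` and `mse_loss` are made up for this illustration, and the dense representation is an assumption that only works for a handful of qubits; the point is just the quantity itself, the mean squared difference between each single-site Pauli-Z expectation value and its target.

```python
import numpy as np

Z = np.array([[1.0, 0.0], [0.0, -1.0]])

def z_expectation(psi, i, n_qubits):
    """<psi| Z_i |psi> for a dense, normalized state vector."""
    # Give every qubit its own axis, then apply Z along axis i.
    psi_t = psi.reshape([2] * n_qubits)
    z_psi = np.moveaxis(np.tensordot(Z, psi_t, axes=(1, i)), 0, i)
    # np.vdot conjugates its first argument and flattens both inputs.
    return np.vdot(psi_t, z_psi).real

def mse_loss(psi, target, n_qubits):
    """Mean squared error between the <Z_i> values and the targets."""
    expecs = np.array([z_expectation(psi, i, n_qubits) for i in range(n_qubits)])
    return np.mean((expecs - np.asarray(target)) ** 2)

plus = np.array([1.0, 1.0]) / np.sqrt(2)
zero = np.array([1.0, 0.0])
psi = np.kron(plus, zero)  # qubit 0 in |+>, qubit 1 in |0>

print(mse_loss(psi, [0.0, 1.0], 2))
print(mse_loss(psi, [1.0, 1.0], 2))
```

Here `<Z_0>` is 0 and `<Z_1>` is 1, so the first target list is matched exactly and the second is off by one on a single site.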

jcmgray commented 3 years ago

Hi @DarthMalloc. If possible could you tidy the code up so I can copy, paste and run? You might need the syntax ```python <long code> ```

Having said that I can envisage a few problems:

Are you just trying to optimize the first declared circuit with respect to local observables? Have you tried getting something working without the autodiff function first? I.e. just a call to loss on its own.

Sorry, I think I need to raise a few more informative errors here. Also, adding functionality to supply a Circuit instance directly to the TNOptimizer would remove a lot of complexity, I imagine; it is on my to-do list.

JustinS6626 commented 3 years ago

Hi jcmgray,

Thank you very much for getting back to me. Here is the latest version of my code with better formatting:

import quimb as qu
import quimb.tensor as qtn
import numpy as np
import jax.numpy as jnp
#import cotengra as ctg

L = 6
N = 2**L
circ = qtn.Circuit(2**L)
dtype = "float64"

for l in range(L):
    scale = 2**l
    for i in range(0, 1 << (L - l), 2):
        circ.su4(*qu.randn(15), scale * i, scale * (i + 1) % N)

    for i in range(1, 1 << (L - l), 2):
        circ.su4(*qu.randn(15), scale * i, scale * (i + 1) % N)

#print([item.tags for item in circ.gates])
target = np.random.random(N).tolist()
qc = qtn.Circuit(N)
#qc.psi = circ.psi

##opt = ctg.ReusableHyperOptimizer(
##    progbar=True,
##    reconf_opts={},
##    max_repeats=16
##    )

def loss(U, psi, circuit, target, n_qubits):
    sites = list(range(n_qubits))
    exp_vals = []
    psi.gate_(U, sites)
    circuit.psi0 = psi
    loss = 0
    i = 0
    obs = qu.pauli('Z')
    while i < n_qubits:
        expect = circuit.local_expectation(obs, i, dtype=dtype)
        exp_vals.append(expect)
        i += 1
    for e, t in zip(exp_vals, target):
        loss += ((e - t) ** 2)
    loss /= n_qubits
    return loss

opt = qtn.optimize.TNOptimizer(circ.uni,
                               loss,
                               loss_kwargs={"psi": circ.psi,
                                            "circuit" : qc,
                                            "target" : target,
                                            "n_qubits" : N},
                               tags=["SU4"],
                               autodiff_backend="torch",
                               optimizer="L-BFGS-B",
                               jit_fn=True)
U_opt = opt.optimize_basinhopping(n=50, nhop=10)

#circ.psi.draw(color=circ.psi.site_tags, legend=False)

It seems that the loss function needs the unitary form as a positional argument, since the circuit itself does not appear to inherit from the base tensor network class. Based on that, what I am trying to do is to evolve an arbitrary initial ket state with the unitary of the circuit, and then plug the evolved ket state into a dummy circuit object so that I can compute local expectations on it, all within the loss function. I am wondering what the correct way to do this would be, since the optimizer seems to be very particular about which object types and operations it will allow in the loss function, depending on which autodiff_backend it uses. On a related note, I would really appreciate it if you could tell me which object types and function calls are forbidden in the loss function when jax is used as the autodiff_backend.
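
As a reading aid for the nested loops in the script above, the following pure-Python sketch lists which qubit pairs receive an su4 gate in each layer. The function name `su4_pairs` is hypothetical and nothing here touches quimb; it only reproduces the index arithmetic, including the fact that `scale * (i + 1) % N` is evaluated as `(scale * (i + 1)) % N`.

```python
def su4_pairs(L):
    """Qubit pairs per layer, using the same index arithmetic as the script."""
    N = 2 ** L
    layers = []
    for l in range(L):
        scale = 2 ** l
        n_blocks = 1 << (L - l)
        # `*` and `%` share precedence and associate left to right,
        # so the second index is (scale * (i + 1)) % N.
        even = [(scale * i, scale * (i + 1) % N) for i in range(0, n_blocks, 2)]
        odd = [(scale * i, scale * (i + 1) % N) for i in range(1, n_blocks, 2)]
        layers.append(even + odd)
    return layers

for l, pairs in enumerate(su4_pairs(3)):
    print(l, pairs)
```

For `L = 3` the last layer contains only the pairs `(0, 4)` and `(4, 0)`, which shows how the modulo wraps the final gate of each layer back to qubit 0.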

JustinS6626 commented 3 years ago

I just wanted to follow up with you on this issue. The problem that I am having with the optimizer is holding up a major project, and I would really like to get it working in the next day or so, so that I can have experimental results this week. If I am using the TNOptimizer class with the jax backend, which functions and object classes will the optimizer not permit? I am sorry for any imposition, and I greatly appreciate your help.

jcmgray commented 3 years ago

Thank you for bearing with the delay.

Regarding objects passed to TNOptimizer, I think this is mostly dealt with in the docstring. quimb functions mostly just orchestrate array operations, so there are no imposed restrictions on these either. That being said, there needs to be a clean flow of computation from the input tensor network to the output (e.g. no conditional behaviour), involving only array operations the backend supports. The arrays you want to differentiate with respect to need to be in the first argument, and other numeric arguments need to be in loss_constants so that they can be found and converted to the correct backend (although jax and others are happy mixing numpy arrays into the computational graph as constants).

The problem here is that U is a TensorNetwork being passed to qtn.gate_TN_1D, which is meant to take a single dense array. TensorNetwork objects have a shape attribute, which makes them look array-like enough to be wrapped in a Tensor object, and that is the issue here. The arrays you want to differentiate are then hidden within a tensor network within a tensor within a tensor network, so they are likely not found as the arrays to differentiate with respect to, and I'm surprised the function works even outside of the autodiff tracer.
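
To illustrate how an object that merely looks array-like can produce the `AttributeError` in the traceback, here is a small sketch that does not use quimb at all. `FakeNetwork`, `looks_like_array` and `common_dtype_names` are hypothetical stand-ins, and the assumption (inferred from the traceback, not confirmed from the library source) is that the wrapped object reports its dtype as a plain string rather than a NumPy dtype object.

```python
import numpy as np

class FakeNetwork:
    """Hypothetical container that passes a loose 'is this an array?' check."""
    shape = (2, 2)
    dtype = "float64"  # a plain string, not a numpy dtype object

def looks_like_array(x):
    # A permissive duck-typing check: anything with a .shape passes.
    return hasattr(x, "shape")

def common_dtype_names(arrays):
    # Same expression as the failing line shown in the traceback.
    return {a.dtype.name for a in arrays}

real = np.eye(2)
fake = FakeNetwork()
assert looks_like_array(real) and looks_like_array(fake)

print(common_dtype_names([real]))
try:
    common_dtype_names([real, fake])
except AttributeError as err:
    message = str(err)
    print(message)
```

The loose check accepts both objects, and the failure only shows up later, far from where the wrong object was passed in, which matches the shape of the traceback reported above.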

The band aid fix here is just to combine the psi and uni properly into a new TN by aligning their indices rather than using gate, but can you clarify your overall goal here? I think there is probably a more efficient way to set it up:

  1. Take an (arbitrary? always another quantum circuit?) input state
  2. Apply another (different?) circuit
  3. Optimize the parameters of both circuits (or just the first/second circuit?) with respect to local expectations (presumably eventually not just 1-local operators?)

Would the following (so far not implemented..) usage be sufficient?

def loss(circ, ops, targets):
    # one expectation value per (sites, operator) pair
    expecs = [
        circ.local_expectation(op, where)
        for where, op in ops.items()
    ]
    return sum(
        (e - t)**2 for e, t in zip(expecs, targets)
    )

tnopt = qtn.TNOptimizer(
    circ,
    loss_fn=loss,
    tags='SU4',
)

circ_opt = tnopt.optimize(1000)

jcmgray commented 3 years ago

For reference, here is how one can currently optimize the wavefunction generated by a Circuit w.r.t. local expectations:

import quimb as qu
import quimb.tensor as qtn
from autoray import do
import functools

# setup a circuit
N = 5
depth = 3
circ = qtn.circ_ansatz_1D_brickwork(N, depth, )

# the local operators
ops = {(i, i + 1): qu.ham_heis(2) for i in range(N - 1)}

# tags describing all the different reverse lightcones for each operators
lightcone_tags = {where: circ.get_reverse_lightcone_tags(where) for where in ops}

# function that takes the input TN and computes the ith loss
def loss_i(psi, where, ops):
    tags = lightcone_tags[where]
    ket  = psi.select(tags, 'any')
    bra = ket.H
    expec = ket.gate(ops[where], where) | bra
    return do('real', expec.contract(all, optimize='auto-hq'))

# since they are a sum we can evaluate them separately
loss_fns = [
    functools.partial(loss_i, where=where)
    for where in ops
]

tnopt = qtn.TNOptimizer(
    circ.psi,
    loss_fn=loss_fns,
    tags='U3',
    loss_constants={'ops': ops},
    autodiff_backend='jax',
)

tnopt.optimize(100)
# -1.9249370098114014: 100%|██████████| 100/100 [00:32<00:00,  3.05it/s]
# <TensorNetwork1DVector(tensors=58, indices=70, L=5, max_bond=2)>

JustinS6626 commented 3 years ago

Thank you very much! That helps a lot. For now, I just have two more questions. First, I am wondering if the restriction against conditionals applies to loops as well. Also, I am wondering if the "where" parameter in the loss_i function can include the entire range of sites in the system.

jcmgray commented 3 years ago

Loops are fine as long as they don't involve conditionals; see this section specific to jax, which is roughly true for most autodiff backends - https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow. Large loops can make compilation very expensive, however.
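
A minimal JAX sketch of that distinction, assuming a reasonably recent `jax` install (the function names are made up for the example): a Python loop with a fixed trip count is simply unrolled during tracing, whereas an `if` on a traced value fails under `jit`, and `jnp.where` is one way to express the same choice as an array operation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def loop_ok(x):
    # A Python loop with a fixed trip count is unrolled while tracing.
    total = 0.0
    for k in range(3):
        total = total + x ** k
    return total

@jax.jit
def branch_bad(x):
    # Branching on a traced value needs a concrete bool, which jit cannot give.
    if x > 0:
        return x
    return -x

@jax.jit
def branch_ok(x):
    # jnp.where selects elementwise, so no Python-level branch is needed.
    return jnp.where(x > 0, x, -x)

print(loop_ok(2.0))  # 1 + 2 + 4
try:
    branch_bad(2.0)
    failed = False
except TypeError as err:  # JAX's tracer errors derive from TypeError
    failed = True
    print(type(err).__name__)
print(branch_ok(-3.0))
```

Because the loop is unrolled, every iteration becomes part of the compiled program, which is why very long loops make compilation expensive.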

And yes you could put where as the entire state, but you would pretty much lose any advantage to using tensor networks with unitary structure. Moreover, currently you'd need to supply the gate G as a size 2**(2 * len(where)) array which is not going to scale well... What are you trying to compute? For it to make sense to use MERA etc, in general one expects k-local operators with 2 <= k << N.
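
To put numbers on that scaling, here is a short pure-Python sketch. The helper `dense_operator_shape` is hypothetical; it just evaluates the `2**(2 * len(where))` element count for a dense operator acting on a given number of qubits.

```python
def dense_operator_shape(n_sites, phys_dim=2):
    """Matrix shape and element count of a dense operator on n_sites qubits."""
    d = phys_dim ** n_sites
    return (d, d), d * d

for k in (1, 2, 4, 10, 64):
    shape, size = dense_operator_shape(k)
    print(k, shape, size)
```

A 2-local operator is a 4 x 4 matrix with 16 entries, while an operator on all 64 qubits of the circuit in this thread would need 2**128 entries, far beyond what can be stored.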

JustinS6626 commented 3 years ago

Thank you very much! Once again, that helps me a great deal. Just one thing that I wanted to clarify: I noticed that you set up the loss function parameters as a list with the results to be called using the autoray.do method. Is that particular approach necessary, or is there some flexibility?

jcmgray commented 3 years ago

Just to explain those choices:

JustinS6626 commented 3 years ago

Once again, thank you very much. That really helps a lot. I tried out a variation of the approach that you showed in the code that you posted last week, but I am still having a few challenges with it. I was therefore wondering if the additional numerical arguments need to be given in the loss_constants dictionary, or if they may be given in the loss_kwargs dictionary. Second, since the optimizer normally suppresses outputs given to stdout, I was wondering if there is a way to get the outputs to show so that I can get benchmarking information from the loss function.

jcmgray commented 3 years ago

> I was therefore wondering if the additional numerical arguments need to be given in the loss_constants dictionary, or if they may be given in the loss_kwargs dictionary.

Constant numerical arguments should be in the loss_constants dict, though some autodiff backends will tolerate them in the loss_kwargs dict, i.e. without converting from numpy.

> Second, since the optimizer normally suppresses outputs given to stdout, I was wondering if there is a way to get the outputs to show so that I can get benchmarking information from the loss function.

You can turn the optimizer's progbar off with progbar=False if it's cluttering the output, but I'm not aware of any output suppression that takes place otherwise. Do you have an example? Note that in a notebook/jupyterlab some output streams are printed to the console (i.e. where you called jupyter-lab).

JustinS6626 commented 3 years ago

Thank you very much for getting back to me about that. In terms of output from the optimizer, what I am trying to do is to use regular print commands in the scope of the loss function to output strings to give diagnostic information. Should print commands work normally inside the loss function when the optimizer uses it, or is there another method that I need to use?

jcmgray commented 3 years ago

Ah sorry, I realise the problem is probably that you are jit-compiling the expression. What that does is trace through just the numeric expressions and compile a 'non-python' version that just reproduces the numeric functions. See this discussion here - https://github.com/google/jax/issues/196. Some backends do actually re-run the python function (torch I think).
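
The effect can be seen in a small standalone JAX sketch (not quimb code; the names `loss` and `trace_calls` are made up for the example). Python side effects such as `print` or appending to a list run while the function is being traced, and the compiled version that is reused afterwards no longer contains them.

```python
import jax
import jax.numpy as jnp

trace_calls = []

@jax.jit
def loss(x):
    # These two lines are Python side effects: they run during tracing only.
    trace_calls.append("traced")
    print("tracing loss with", x)
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)
first = loss(x)          # traces and compiles, so the print fires
second = loss(x + 1.0)   # same shape and dtype: compiled version is reused
print(len(trace_calls))  # 1
```

The printed argument is an abstract tracer rather than concrete numbers, which is another sign that the Python body is not what runs on each evaluation. Newer JAX releases also provide `jax.debug.print` for printing traced values at run time, though that was not available when this thread was written.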

Does the diagnostic information you require need to be gathered during the actual gradient descent?

JustinS6626 commented 3 years ago

I managed to get the optimization process working. Thank you again very much for your help!