Hi @DarthMalloc. If possible could you tidy the code up so I can copy, paste and run? You might need the syntax
```python
<long code>
```
Having said that, I can envisage a few problems:

- `psi.gate_(U, sites)` - I think `gate_` is not meant to take an entire `TensorNetwork2DOperator` as `U` - not quite sure what will end up happening here
- you take `uni` and `psi` and then manually add the `uni` to `psi` again (so `U U psi0`) before putting it in as the initial state - not sure if this is intended. Also, the clever lightcone cancelling stuff in `Circuit` only knows about the gates it applies, not the input state.

Are you just trying to optimize the first declared circuit with respect to local observables? Have you tried getting something working without the autodiff function first? I.e. just a call to `loss` on its own.
Sorry, I think I need to raise a few more informative errors to address this. Also, adding functionality to supply a `Circuit` instance directly to the `TNOptimizer` would alleviate a lot of complexity I imagine - on my to-do list.
Hi jcmgray,
Thank you very much for getting back to me. Here is the latest version of my code with better formatting:
```python
import quimb as qu
import quimb.tensor as qtn
import numpy as np
import jax.numpy as jnp
#import cotengra as ctg

L = 6
N = 2**L
circ = qtn.Circuit(2**L)
dtype = "float64"

for l in range(L):
    scale = 2**l
    for i in range(0, 1 << (L - l), 2):
        circ.su4(*qu.randn(15), scale * i, scale * (i + 1) % N)
    for i in range(1, 1 << (L - l), 2):
        circ.su4(*qu.randn(15), scale * i, scale * (i + 1) % N)

#print([item.tags for item in circ.gates])
target = np.random.random(N).tolist()
qc = qtn.Circuit(N)
#qc.psi = circ.psi
##opt = ctg.ReusableHyperOptimizer(
##    progbar=True,
##    reconf_opts={},
##    max_repeats=16
##    )

def loss(U, psi, circuit, target, n_qubits):
    sites = list(range(n_qubits))
    exp_vals = []
    psi.gate_(U, sites)
    circuit.psi0 = psi
    loss = 0
    i = 0
    obs = qu.pauli('Z')
    while i < n_qubits:
        expect = circuit.local_expectation(obs, i, dtype=dtype)
        exp_vals.append(expect)
        i += 1
    for e, t in zip(exp_vals, target):
        loss += ((e - t) ** 2)
    loss /= n_qubits
    return loss

opt = qtn.optimize.TNOptimizer(circ.uni,
                               loss,
                               loss_kwargs={"psi": circ.psi,
                                            "circuit": qc,
                                            "target": target,
                                            "n_qubits": N},
                               tags=["SU4"],
                               autodiff_backend="torch",
                               optimizer="L-BFGS-B",
                               jit_fn=True)
U_opt = opt.optimize_basinhopping(n=50, nhop=10)
#circ.psi.draw(color=circ.psi.site_tags, legend=False)
```
It seems that the loss function needs the unitary form as a positional argument, since the circuit itself does not appear to inherit the base tensor network class. Based on that, what I am trying to do is to evolve an arbitrary initial ket state with the unitary of the circuit, and then plug the evolved ket state into a dummy circuit object so that I can perform local expectations on it, all within the loss function. I am wondering what the correct way to do this would be, since the optimizer seems to be very particular about which object types and operations it will allow in the loss function, depending on which autodiff_backend it uses. On a related note, I would really appreciate it if you could tell me which object types and function calls are forbidden in the loss function if jax is being used as the autodiff_backend.
I just wanted to follow up with you on this issue. The problem that I am having with the optimizer is holding up a major project, and I would really like to get it working in the next day or so, so that I can have experimental results this week. If I am using the TNOptimizer class with the jax backend, which functions and object classes will the optimizer not permit? I am sorry for any imposition, and I greatly appreciate your help.
Thank you for bearing with the delay.
Regarding objects passed to `TNOptimizer`, I think this is mostly dealt with in the docstring. `quimb` functions mostly just orchestrate array operations, so there are no imposed restrictions on these either. That being said, there needs to be a clean flow of computation from the input tensor network to the output (e.g. no conditional behaviour), involving only array operations the backend supports. The arrays you want to differentiate with respect to need to be in the first argument, and other numeric arguments need to be in `loss_constants` so that they can be found and converted to the correct backend (although `jax` and others are happy mixing numpy arrays into the computational graph as constants).
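To make the general shape of that concrete, here is a minimal sketch - the random MPS, Heisenberg MPO and the particular loss are just placeholders for your own network and objective:

```python
import quimb.tensor as qtn

# the tensor network given as the first argument holds the arrays to differentiate
psi = qtn.MPS_rand_state(6, bond_dim=4)

def norm_fn(psi):
    # keep the state normalized during the optimization
    return psi / (psi.H @ psi)**0.5

def loss_fn(psi, ham):
    # a clean flow of array operations from the input TN to a scalar
    return qtn.expec_TN_1D(psi.H, ham, psi)

tnopt = qtn.TNOptimizer(
    psi,
    loss_fn=loss_fn,
    norm_fn=norm_fn,
    # constant numeric arguments go here so they are found and converted to the backend
    loss_constants={'ham': qtn.MPO_ham_heis(6)},
    autodiff_backend='jax',
)
psi_opt = tnopt.optimize(100)
```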
The problem here is that `U` is a `TensorNetwork` being passed to `qtn.gate_TN_1D`, which is meant to take a single dense array. `TensorNetwork` objects have a shape attribute which makes them look array-like enough to be wrapped in a `Tensor` object, which is the issue here. Then the arrays you want to differentiate are hidden within a tensor network within a tensor within a tensor network, so likely they are not found as the arrays to differentiate with respect to, and I'm surprised the function works even outside of the autodiff tracer.
The band-aid fix here is just to combine the `psi` and `uni` properly into a new TN by aligning their indices rather than using `gate` - a rough sketch of what I mean is shown below - but can you clarify your overall goal here? I think there is probably a more efficient way to set it up.
As for the more efficient setup, would the following (so far not implemented..) usage be sufficient?
```python
def loss(circ, ops, targets):
    expecs = [
        circ.local_expectation(where, op)
        for where, op in ops.items()
    ]
    return sum(
        (e - t)**2 for e, t in zip(expecs, targets)
    )

tnopt = qtn.TNOptimizer(
    circ,
    loss=loss,
    tags='SU4',
)
circ_opt = tnopt.optimize(1000)
```
For reference, here is how one can currently optimize the wavefunction generated by a `Circuit` with respect to local expectations:
```python
import quimb as qu
import quimb.tensor as qtn
from autoray import do
import functools

# setup a circuit
N = 5
depth = 3
circ = qtn.circ_ansatz_1D_brickwork(N, depth)

# the local operators
ops = {(i, i + 1): qu.ham_heis(2) for i in range(N - 1)}

# tags describing all the different reverse lightcones for each operator
lightcone_tags = {where: circ.get_reverse_lightcone_tags(where) for where in ops}

# function that takes the input TN and computes the ith loss
def loss_i(psi, where, ops):
    tags = lightcone_tags[where]
    ket = psi.select(tags, 'any')
    bra = ket.H
    expec = ket.gate(ops[where], where) | bra
    return do('real', expec.contract(all, optimize='auto-hq'))

# since they are a sum we can evaluate them separately
loss_fns = [
    functools.partial(loss_i, where=where)
    for where in ops
]

tnopt = qtn.TNOptimizer(
    circ.psi,
    loss_fn=loss_fns,
    tags='U3',
    loss_constants={'ops': ops},
    autodiff_backend='jax',
)
tnopt.optimize(100)
# -1.9249370098114014: 100%|██████████| 100/100 [00:32<00:00,  3.05it/s]
# <TensorNetwork1DVector(tensors=58, indices=70, L=5, max_bond=2)>
```
Thank you very much! That helps a lot. For now, I just have two more questions. First, I am wondering if the restriction against conditionals applies to loops as well. Also, I am wondering if the "where" parameter in the loss_i function can include the entire range of sites in the system.
Loops are fine as long as they don't involve conditionals - see this section specific to `jax`, which is roughly true for most autodiff backends: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow. Large loops can make compilation very expensive, however.
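As a tiny illustration of the distinction (plain jax, nothing quimb-specific):

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(x):
    total = 0.0
    # a plain python loop over a fixed range is fine: jax simply unrolls it
    # at trace time (which is also why very long loops compile slowly)
    for i in range(x.shape[0]):
        total = total + x[i] ** 2
    # by contrast, a branch like `if total > 1.0: ...` would error out at
    # trace time, since it depends on the value of a traced quantity
    return total

print(loss(jnp.arange(4.0)))  # 0 + 1 + 4 + 9 = 14.0
```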
And yes, you could put `where` as the entire state, but you would pretty much lose any advantage of using tensor networks with unitary structure. Moreover, currently you'd need to supply the gate `G` as a size `2**(2 * len(where))` array, which is not going to scale well... What are you trying to compute? For it to make sense to use MERA etc., in general one expects k-local operators with `2 <= k << N`.
Thank you very much! Once again, that helps me a great deal. Just one thing that I wanted to clarify: I noticed that you set up the loss as a list of functions, with the results computed using the autoray.do method. Is that particular approach necessary, or is there some flexibility?
Just to explain those choices: `autoray.do` is used so that whatever `autodiff_backend` is being used, the `'real'` function dispatches correctly. If you were only interested in using `jax`, for example, you could use a backend-specific function like `jax.numpy.real`.
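For example, a trivial illustration of the dispatch (not quimb-specific):

```python
import numpy as np
import jax.numpy as jnp
from autoray import do

x_np = np.array([1 + 2j])
x_jx = jnp.array([1 + 2j])

# `do` inspects the type of its argument and dispatches accordingly
do('real', x_np)   # -> a numpy result
do('real', x_jx)   # -> a jax result, so it stays inside the traced graph
```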
Once again, thank you very much. That really helps a lot. I tried out a variation of the approach that you showed in the code that you posted last week, but I am still having a few challenges with it. I was therefore wondering if the additional numerical arguments need to be given in the `loss_constants` dictionary, or if they may be given in the `loss_kwargs` dictionary. Second, since the optimizer normally suppresses outputs given to stdout, I was wondering if there is a way to get the outputs to show so that I can get benchmarking information from the loss function.
> I was therefore wondering if the additional numerical arguments need to be given in the loss_constants dictionary, or if they may be given in the loss_kwargs dictionary.
Constant numerical arguments should be in the `loss_constants` dict, though some autodiff backends will tolerate them in the `loss_kwargs` dict, i.e. without converting from numpy.
> Second, since the optimizer normally suppresses outputs given to stdout, I was wondering if there is a way to get the outputs to show so that I can get benchmarking information from the loss function.
You can turn the optimizer's progbar off with `progbar=False` if it's cluttering the output, but I'm not aware of any output suppression that takes place otherwise. Do you have an example? Note that in a notebook/jupyterlab some output streams are printed to the console (i.e. where you called `jupyter-lab`).
Thank you very much for getting back to me about that. In terms of output from the optimizer, what I am trying to do is to use regular `print` commands in the scope of the loss function to output strings giving diagnostic information. Should `print` commands work normally inside the loss function when the optimizer uses it, or is there another method that I need to use?
Ah sorry, I realise the problem is probably that you are jit compiling the expression; what that does is trace through just the numeric expressions and compile a 'non-python' version that just reproduces the numeric functions. See this discussion here - https://github.com/google/jax/issues/196. Some backends do actually re-run the python function (torch I think).
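For illustration, a minimal sketch of that behaviour in plain jax:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # this only runs while jax traces the function, i.e. on the first call
    # (and on re-compilation), and it prints an abstract tracer, not values
    print("tracing, x =", x)
    return (x ** 2).sum()

f(jnp.arange(3.0))      # prints once, during tracing
f(jnp.arange(3.0) + 1)  # no print: the compiled version is reused
```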
Does the diagnostic information you require need to be gathered during the actual gradient descent?
I managed to get the optimization process working. Thank you again very much for your help!
I have built the circuit structure that you showed in https://github.com/jcmgray/quimb/issues/76 and what I am trying to do now is to optimize the circuit using the TNOptimizer class with basin-hopping. The code for my approach is as follows:
```python
import quimb as qu
import quimb.tensor as qtn
import numpy as np
import jax.numpy as jnp

L = 6
N = 2**L
circ = qtn.Circuit(2**L)

for l in range(L):
    scale = 2**l
    for i in range(0, 1 << (L - l), 2):
        circ.su4(*qu.randn(15), scale * i, scale * (i + 1) % N)
    for i in range(1, 1 << (L - l), 2):
        circ.su4(*qu.randn(15), scale * i, scale * (i + 1) % N)

#print([item.tags for item in circ.gates])
target = np.random.random(N).tolist()

def loss(U, psi, target, n_qubits):
    sites = list(range(n_qubits))
    exp_vals = []
    psi.gate_(U, sites)
    qc = qtn.Circuit(n_qubits, psi0=psi)
    loss = 0
    i = 0
    obs = qu.pauli('Z')
    while i < n_qubits:
        expect = qc.local_expectation(obs, i, backend="jax").item().real
        exp_vals.append(expect)
        i += 1
    for e, t in zip(exp_vals, target):
        loss += ((e - t) ** 2)
    loss /= n_qubits
    return loss

opt = qtn.optimize.TNOptimizer(circ.uni,
                               loss,
                               loss_kwargs={"psi": circ.psi,
                                            "target": target,
                                            "n_qubits": N},
                               tags=["SU4"],
                               autodiff_backend="jax",
                               optimizer="L-BFGS-B")
U_opt = opt.optimize_basinhopping(n=50, nhop=10)
```
When I try to run the code, I get the following error traceback chain:
```
  0%|          | 0/500 [00:00<?, ?it/s]WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  0%|          | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/justin/darthmallocsarchive-svn/trunk/DissertationExperiment/circuit_MERA_test.py", line 48, in <module>
    U_opt = opt.optimize_basinhopping(n=50, nhop=10)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 946, in optimize_basinhopping
    self.res = basinhopping(
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 678, in basinhopping
    bh = BasinHoppingRunner(x0, wrapped_minimizer, take_step_wrapped,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 72, in __init__
    minres = minimizer(self.x)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 284, in __call__
    return self.minimizer(self.func, x0, **self.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_minimize.py", line 617, in minimize
    return _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/lbfgsb.py", line 306, in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 261, in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 76, in __init__
    self._update_fun()
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 166, in _update_fun
    self._update_fun_impl()
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 73, in update_fun
    self.f = fun_wrapped(self.x)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 70, in fun_wrapped
    return fun(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 74, in __call__
    self._compute_if_needed(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 68, in _compute_if_needed
    fg = self.fun(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 846, in vectorized_value_and_grad
    ag_result, ag_grads = self.handler.value_and_grad(arrays)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 228, in value_and_grad
    loss, grads = self._value_and_grad(arrays)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 695, in __call__
    return self.loss_fn(self.norm_fn(tn_compute))
  File "/home/justin/darthmallocsarchive-svn/trunk/DissertationExperiment/circuit_MERA_test.py", line 32, in loss
    expect = qc.local_expectation(obs, i, backend="jax").item().real
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1490, in local_expectation
    rho = self.get_rdm_lightcone_simplified(
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1171, in get_rdm_lightcone_simplified
    ket_lc = self.get_psi_reverse_lightcone(where)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1052, in get_psi_reverse_lightcone
    psi = self.psi
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 946, in psi
    psi.astype(psi.dtype)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/tensor_core.py", line 5428, in dtype
    return common_type(*self)
  File "/usr/local/lib/python3.8/dist-packages/quimb/core.py", line 293, in common_type
    dtypes = {array.dtype.name for array in arrays}
  File "/usr/local/lib/python3.8/dist-packages/quimb/core.py", line 293, in <setcomp>
    dtypes = {array.dtype.name for array in arrays}
jax._src.traceback_util.FilteredStackTrace: AttributeError: 'str' object has no attribute 'name'

The stack trace above excludes JAX-internal frames. The following is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/justin/darthmallocsarchive-svn/trunk/DissertationExperiment/circuit_MERA_test.py", line 48, in <module>
    U_opt = opt.optimize_basinhopping(n=50, nhop=10)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 946, in optimize_basinhopping
    self.res = basinhopping(
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 678, in basinhopping
    bh = BasinHoppingRunner(x0, wrapped_minimizer, take_step_wrapped,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 72, in __init__
    minres = minimizer(self.x)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_basinhopping.py", line 284, in __call__
    return self.minimizer(self.func, x0, **self.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_minimize.py", line 617, in minimize
    return _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/lbfgsb.py", line 306, in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 261, in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 76, in __init__
    self._update_fun()
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 166, in _update_fun
    self._update_fun_impl()
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 73, in update_fun
    self.f = fun_wrapped(self.x)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/_differentiable_functions.py", line 70, in fun_wrapped
    return fun(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 74, in __call__
    self._compute_if_needed(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/scipy/optimize/optimize.py", line 68, in _compute_if_needed
    fg = self.fun(x, *args)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 846, in vectorized_value_and_grad
    ag_result, ag_grads = self.handler.value_and_grad(arrays)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 228, in value_and_grad
    loss, grads = self._value_and_grad(arrays)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/api.py", line 371, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/api.py", line 278, in cache_miss
    out_flat = xla.xla_call(
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1229, in bind
    return call_bind(self, fun, *args, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1220, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1232, in process
    return trace.process_call(self, fun, tracers, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 598, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/xla.py", line 569, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/xla.py", line 645, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/jax/api.py", line 769, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/usr/local/lib/python3.8/dist-packages/jax/api.py", line 1846, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py", line 114, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 516, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/optimize.py", line 695, in __call__
    return self.loss_fn(self.norm_fn(tn_compute))
  File "/home/justin/darthmallocsarchive-svn/trunk/DissertationExperiment/circuit_MERA_test.py", line 32, in loss
    expect = qc.local_expectation(obs, i, backend="jax").item().real
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1490, in local_expectation
    rho = self.get_rdm_lightcone_simplified(
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1171, in get_rdm_lightcone_simplified
    ket_lc = self.get_psi_reverse_lightcone(where)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 1052, in get_psi_reverse_lightcone
    psi = self.psi
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/circuit.py", line 946, in psi
    psi.astype(psi.dtype)
  File "/usr/local/lib/python3.8/dist-packages/quimb/tensor/tensor_core.py", line 5428, in dtype
    return common_type(*self)
  File "/usr/local/lib/python3.8/dist-packages/quimb/core.py", line 293, in common_type
    dtypes = {array.dtype.name for array in arrays}
  File "/usr/local/lib/python3.8/dist-packages/quimb/core.py", line 293, in <setcomp>
    dtypes = {array.dtype.name for array in arrays}
AttributeError: 'str' object has no attribute 'name'
```
The problem seems to be in the call to the local_expectation method, and I am wondering if there is a correct way to carry out what I am trying to achieve in this example. The example that I am showing is a toy example for a more complicated task, but the main objective that I am trying to achieve right now for that is to optimize a circuit using the jax backend with a loss function that calls the local_expectation method. Any help with that would be greatly appreciated.