nikitn2 opened this issue 1 year ago.
Hi @nikitn2, turning on compressed contraction through the usual (exact) `contract` interface is still quite experimental. I believe the issue is that when you supply `max_bond`, it switches to the compressed contraction method, which has a non-zero default `cutoff`. With such a cutoff, the number of singular values kept depends on the data, so the generated shapes can be dynamic, which tracing libraries like jax don't like. You just need to supply `cutoff=0.0`.
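A plain NumPy sketch of why a non-zero cutoff makes shapes dynamic. `truncated_svd` is a simplified, hypothetical truncation (relative cutoff first, then `max_bond`), not quimb's actual code: with a cutoff, the number of kept singular values depends on the matrix entries, whereas `max_bond` with `cutoff=0.0` fixes the output shape from the input shape alone.

```python
import numpy as np


def truncated_svd(x, max_bond=None, cutoff=0.0):
    """Simplified SVD truncation (hypothetical, not quimb's implementation).

    Keeps singular values s_i > cutoff * s_max (a relative cutoff is an
    assumption here), then at most max_bond of them.
    """
    u, s, vh = np.linalg.svd(x, full_matrices=False)
    keep = len(s)
    if cutoff > 0.0:
        # data-dependent: the kept rank depends on the values in x
        keep = int(np.sum(s > cutoff * s[0]))
    if max_bond is not None:
        # shape-dependent only: known before looking at any values
        keep = min(keep, max_bond)
    return u[:, :keep], s[:keep], vh[:keep, :]


rng = np.random.default_rng(0)
full = rng.normal(size=(8, 8))                                # full rank
low = rng.normal(size=(8, 2)) @ rng.normal(size=(2, 8))       # rank 2

# same input shape, different output shapes when a cutoff is used
print(truncated_svd(full, max_bond=4, cutoff=1e-10)[1].shape)
print(truncated_svd(low, max_bond=4, cutoff=1e-10)[1].shape)
```

Under JAX tracing only shapes are known, not values, so the cutoff branch cannot produce a traceable shape, while the `max_bond`-only branch can.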
The general problem is the difficulty of choosing minimal default arguments for these advanced algorithms without hiding important details. For example, compressed contraction shouldn't be used with the current default exact path optimizers such as `optimize='greedy'`, but eventually it would be nice for it to be a simple switch.
Thanks for the kind words about quimb!
Hi @jcmgray,
Thanks for the reply; your explanation makes perfect sense. It's probably too much to expect an autodiff framework to handle dynamically sized arrays easily, and I can see how difficult it is to integrate these considerations into a numerical library as advanced as yours while still keeping quimb simple to use. Dilemmas...!
Your `cutoff=0.0` trick does fix the issue in the example I provided. However, what about the other compression methods? I need to use `tensor_network_apply_op_vec()` or `MatrixProductState.compress()` in my code, and there your solution doesn't work for me. For example, running the snippet below still results in the same error as before:
```python
import quimb as qu
import quimb.tensor as qtn
from quimb.tensor.optimize import TNOptimizer

chi = 8
L = 16

# spin-1 Heisenberg chain as an MPO
builder = qtn.SpinHam1D(S=1)
builder += 1/2, '+', '-'
builder += 1/2, '-', '+'
builder += 1, 'Z', 'Z'
H = builder.build_mpo(L)

bond_dim = 16
mps = qtn.MPS_rand_state(L, bond_dim, phys_dim=3, cyclic=False)


def normalize_state(psi):
    return psi / (psi.H @ psi) ** 0.5


def energy(psi, H):
    # Alternative 1: apply the MPO and compress in a single call
    Hpsi = qtn.tensor_arbgeom.tensor_network_apply_op_vec(
        tn_op=H, tn_vec=psi, compress=True, max_bond=chi, cutoff=0.0
    )
    # Alternative 2: apply the MPO exactly, then compress separately
    # Hpsi = qtn.tensor_arbgeom.tensor_network_apply_op_vec(
    #     tn_op=H, tn_vec=psi, compress=False
    # )
    # Hpsi.compress(max_bond=chi, cutoff=0.0)
    tn = psi.H & Hpsi
    return tn ^ ...


optmzr = TNOptimizer(
    mps,
    loss_fn=energy,
    norm_fn=normalize_state,
    loss_constants={'H': H},
    autodiff_backend='jax',
    optimizer='L-BFGS-B',
)
mps_opt = optmzr.optimize(100)
```
Is there perhaps a way for me to just globally turn off dynamic shaping of tensors when using autodiff?
Hi @jcmgray,
I've been playing around with this issue in my code, and I find that JAX basically doesn't like it when bond dimensions are allocated dynamically in any way.
I've found two workarounds so far. First, in my example above, if you pass `renorm=0` along with `cutoff=0.0`, the autodiff runs; sadly, it also becomes numerically unstable... Second, more generally in my code, the line `opts['max_bond'] = _MAX_BOND_LOOKUP.get(max_bond, max_bond)` in `tensor_core.py` causes problems. Changing it to `opts['max_bond'] = -1 if max_bond is None else max_bond` seems to mesh better with JAX.
Do you think there is anything I can do to get JAX autodiff to work with compression in a more stable manner?
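A small NumPy illustration of why the dict lookup breaks, with a NumPy array standing in for the JAX tracer (both are unhashable). `_LOOKUP` is a hypothetical table mapping `None` to `-1`; the exact contents of quimb's `_MAX_BOND_LOOKUP` are assumed, not checked. `dict.get` has to hash its key, whereas an `is None` check never does.

```python
import numpy as np

# hypothetical stand-in for quimb's _MAX_BOND_LOOKUP (contents assumed)
_LOOKUP = {None: -1}


def parse_max_bond_lookup(max_bond):
    # dict.get hashes its key, so an unhashable max_bond raises TypeError
    return _LOOKUP.get(max_bond, max_bond)


def parse_max_bond_safe(max_bond):
    # an identity check never hashes its argument
    return -1 if max_bond is None else max_bond


try:
    parse_max_bond_lookup(np.array(8))  # ndarray, like a tracer, is unhashable
except TypeError as err:
    print(err)  # unhashable type: 'numpy.ndarray'

print(parse_max_bond_safe(np.array(8)))
```

This only fixes the hashing: if `max_bond` really is a tracer, any later comparison on it in an `if` will still fail, which is the kind of `ConcretizationTypeError` that comes up later in this thread.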
Just a couple of things:
1. Is the 1D setup above your main use case? Generally compression is not needed in such cases: above, for example, contracting the MPO into the MPS and then compressing is more expensive than contracting the MPO expectation value directly. As for `renorm`, it is indeed set to 0 for non-1D tensor networks in quimb, since the standard way to contract these is usually with a fixed bond dimension (and since there is no canonical form, you can't rely on the singular values as much for truncation).
2. Autodiff and compressed contraction is inherently unstable in many cases! Unless the approximate contraction is numerically close to the exact one, the optimizer learns to exploit the difference to produce unphysical results: for example, above you are really computing $\langle\psi|H'|\psi'\rangle$, where $H'$ might not be Hermitian, $|\psi'\rangle$ might differ from $|\psi\rangle$, etc.
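In symbols (a short sketch; $E_0$ denotes the exact ground-state energy, and $|\psi\rangle$ is normalised, as `normalize_state` ensures in the snippet above), the variational bound holds only for the exact expectation value:

$$
E[\psi] = \langle\psi|H|\psi\rangle \ge E_0 \quad \text{for all normalised } |\psi\rangle,
\qquad
E'[\psi] = \langle\psi|H'|\psi'\rangle .
$$

With $H'$ possibly non-Hermitian and $|\psi'\rangle$ possibly different from $|\psi\rangle$, $E'$ need not be real, let alone bounded below by $E_0$, so an optimiser minimising it can drift below the true ground-state energy without the state actually improving.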
Hi @jcmgray,
Thanks very much for the reply! You're right, it does indeed make perfect sense that the above example should be numerically unstable.
The above example is not my main use case, though. My main use case involves applying operators $O = O(|\psi\rangle)$ with very large bond dimensions to wavefunctions $|\psi\rangle$, in order to compute a loss function of the type $L = \| O|\psi\rangle - |\phi\rangle \|^2 = \langle O\psi|O\psi\rangle + \langle\phi|\phi\rangle - \langle O\psi|\phi\rangle - \langle\phi|O\psi\rangle$. In my code, $O$ is a tensor network operator (though I could contract it into MPO form), and $|\psi\rangle$ and $|\phi\rangle$ are in either MPS or tree tensor network form.
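A quick dense-vector check of that expansion, with NumPy arrays standing in for the tensor networks (`loss_direct` and `loss_expanded` are illustrative names). Note that `np.vdot` conjugates its first argument, so `np.vdot(a, b)` is $\langle a|b\rangle$.

```python
import numpy as np


def loss_direct(O, psi, phi):
    # L = || O|psi> - |phi> ||^2, computed from the residual vector
    r = O @ psi - phi
    return np.vdot(r, r).real


def loss_expanded(O, psi, phi):
    # the same L assembled from the four overlaps
    Opsi = O @ psi
    return (np.vdot(Opsi, Opsi) + np.vdot(phi, phi)
            - np.vdot(Opsi, phi) - np.vdot(phi, Opsi)).real


rng = np.random.default_rng(1)
n = 16
O = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
psi = rng.normal(size=n) + 1j * rng.normal(size=n)
phi = rng.normal(size=n) + 1j * rng.normal(size=n)
print(loss_direct(O, psi, phi), loss_expanded(O, psi, phi))
```

The two cross terms are complex conjugates of each other, so $L = \langle O\psi|O\psi\rangle + \langle\phi|\phi\rangle - 2\,\mathrm{Re}\,\langle\phi|O\psi\rangle$ and only one of them needs to be contracted.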
My first idea was to contract $O$ into an MPO, then use the zip-up algorithm to compute the $|O\psi\rangle$ terms reasonably cheaply, after which $L$ can be calculated by simply adding up overlaps of various TNs. However, JAX didn't like my zip-up implementation... It gave the error:
```
File "/Users/ngourianov/opt/anaconda3/lib/python3.9/site-packages/quimb/tensor/tensor_core.py", line 317, in _parse_split_opts
    opts['max_bond'] = _MAX_BOND_LOOKUP.get(max_bond, max_bond)
TypeError: unhashable type: 'DynamicJaxprTracer'
```
After applying the fix I mentioned in my previous post, I now get a JAX `ConcretizationTypeError` originating at the `elif max_bond > 0:` line in the `_trim_and_renorm_SVD` function. Not really sure how to proceed with that.
Giving up on this, I tried computing the four terms of $L$ using `TensorNetwork.contract()` with compression switched on, hoping it would compute $L$ with the same complexity as the zip-up algorithm. However, this doesn't work either. Even when I set `cutoff=0.0`, I get the error:
```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
The error occurred while tracing the function <unknown> for jit. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = lt b c
    from line /Users/nik/opt/anaconda3/lib/python3.9/site-packages/quimb/tensor/tensor_core.py:6149 (_compress_neighbors)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
The only way I can compute and minimise $L$ using JAX is by using `TensorNetwork.contract()` without any compression. But this isn't really scalable due to the large bond dimension of $O$...
Do you have any ideas on how to proceed? I tried TensorFlow instead of JAX, and it actually worked with compression turned on, although it was almost two orders of magnitude slower than JAX...
EDIT:
I'm thinking of just restricting $O$ to be an MPO and $|\psi\rangle$ and $|\phi\rangle$ to be MPSs. Presumably this should be simple enough for the compressed contraction to work. Would the compressed contraction algorithm then automatically implement the zip-up algorithm?
I see, so it's a fitting task (quimb does actually have TN fitting functionality, but for full control you might want to handle things manually).
> After applying the fix I mentioned in my previous post, I now get a JAX `ConcretizationTypeError` originating at the `elif max_bond > 0:` line in the `_trim_and_renorm_SVD` function. Not really sure how to proceed with that.
This is saying that `max_bond` is being set dynamically somewhere, so its concrete value is not available when tracing (which is also why it's not hashable; changing away from the dict lookup probably makes sense too, however). I suppose you need to check the algorithm to make sure `max_bond` is always a constant.
> The only way I can compute and minimise $L$ using JAX is by using `TensorNetwork.contract()` without any compression. But this isn't really scalable due to the large bond dimension of $O$...
Are you sure the complexity of even the zip-up algorithm is not similar to that of exactly contracting $\langle\phi|O|\psi\rangle$? Since I guess you are not optimizing $O$ itself (?), you don't need the constant $O^2$ overlap term. Or is it the memory overhead of back-propagation that is too much?
> I'm thinking of just restricting $O$ to be an MPO and $|\psi\rangle$ and $|\phi\rangle$ to be MPSs. Presumably this should be simple enough for the compressed contraction to work. Would the compressed contraction algorithm then automatically implement the zip-up algorithm?
What the compressed contraction algorithm does depends entirely on the contraction path (the `optimize` kwarg). For 1D and tree-like TNs no compression is required, since you can contract exactly without increasing the size of the intermediates. My understanding of the 'apply MPO to MPS and compress' algorithms is that they are useful when you are applying several or many MPOs, i.e. when you effectively have a 2D geometry.
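To make the 1D point concrete, here is a plain NumPy sketch (not quimb's code path; helper names and index conventions are my own) that contracts $\langle\phi|O|\psi\rangle$ exactly for an MPS-MPO-MPS sandwich by sweeping an environment from left to right. The largest intermediate is about $\chi_\phi D d \chi_\psi$ entries, no matter how many sites there are.

```python
import numpy as np


def random_mps(L, d, chi, rng):
    # site tensors indexed (left bond, physical, right bond); edge bonds have size 1
    dims = [1] + [chi] * (L - 1) + [1]
    return [rng.normal(size=(dims[i], d, dims[i + 1])) for i in range(L)]


def random_mpo(L, d, D, rng):
    # site tensors indexed (left bond, physical out, physical in, right bond)
    dims = [1] + [D] * (L - 1) + [1]
    return [rng.normal(size=(dims[i], d, d, dims[i + 1])) for i in range(L)]


def sandwich(phi, O, psi):
    """Exact <phi|O|psi> via a left-to-right environment sweep.

    Returns the value and the largest intermediate size encountered.
    """
    env = np.ones((1, 1, 1))  # indices: (phi bond, O bond, psi bond)
    largest = env.size
    for A, W, B in zip(phi, O, psi):
        t1 = np.einsum('abc,aps->bcps', env, A.conj())  # absorb bra site
        t2 = np.einsum('bcps,bpqt->csqt', t1, W)        # absorb operator site
        env = np.einsum('csqt,cqu->stu', t2, B)         # absorb ket site
        largest = max(largest, t1.size, t2.size, env.size)
    return env[0, 0, 0], largest


def mps_to_dense(mps):
    # contract the chain into a length d**L vector (small L only)
    v = mps[0]
    for A in mps[1:]:
        v = np.tensordot(v, A, axes=([-1], [0]))
    return v.reshape(-1)


def mpo_to_dense(mpo):
    # contract the chain into a (d**L, d**L) matrix (small L only)
    M = mpo[0]
    for W in mpo[1:]:
        M = np.tensordot(M, W, axes=([-1], [0]))
    L, d = len(mpo), mpo[0].shape[1]
    M = M.reshape(M.shape[1:-1])  # drop the size-1 edge bonds
    # axes are (out_0, in_0, out_1, in_1, ...): group all outs, then all ins
    M = M.transpose(list(range(0, 2 * L, 2)) + list(range(1, 2 * L, 2)))
    return M.reshape(d ** L, d ** L)


rng = np.random.default_rng(0)
L, d, chi, D = 6, 2, 3, 4
phi, psi = random_mps(L, d, chi, rng), random_mps(L, d, chi, rng)
O = random_mpo(L, d, D, rng)
val, largest = sandwich(phi, O, psi)
ref = mps_to_dense(phi).conj() @ mpo_to_dense(O) @ mps_to_dense(psi)
print(val, ref, largest)
```

The dense check is only feasible for small systems. For the $\langle\psi|O^\dagger O|\psi\rangle$ term the same sweep carries two MPO bonds, so the environment grows to roughly $\chi^2 D^2$ entries, which is where a very large $D$ starts to hurt.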
That being said, if the MPO has a really large bond dimension, maybe that's different enough from '1D and tree-like' to warrant some compression somewhere; you could run a `cotengra` `HyperCompressedOptimizer` to search for it. But I would stress that this stuff is not 'officially' supported in quimb yet.
Hi @jcmgray,
Thanks so much for your reply!
> Are you sure the complexity of even the zip-up algorithm is not similar to that of exactly contracting $\langle\phi|O|\psi\rangle$? Since I guess you are not optimizing $O$ itself (?), you don't need the constant $O^2$ overlap term. Or is it the memory overhead of back-propagation that is too much?
It's a bit worse than that, I'm afraid, as I'm basically dealing with a highly nonlinear problem. In my case, $O$ is itself a sum of operators $O_1, O_2, \ldots$, with one of them, say $O_N$, having a very high bond dimension and depending on the variational function itself: $O_N = O_N(|\psi\rangle)$. So unfortunately I can't really use your fit function. And I do think I also need to calculate the overlap term $\langle\psi|O^\dagger O|\psi\rangle$ when calculating the loss function $L$, unless I've missed something?
Therefore, even if I represent $|\psi\rangle$ as an MPS and $O$ as an MPO, the resulting loss function $L = \langle\psi|O^\dagger O|\psi\rangle + \langle\phi|\phi\rangle - \langle\psi|O^\dagger|\phi\rangle - \langle\phi|O|\psi\rangle$ will still be a “2D geometry”, as you put it, which is why I'd like to use compressed contraction.
So to calculate $L$, it would be nice to first precompute $|O\psi\rangle = O|\psi\rangle = O_1|\psi\rangle + \ldots + O_N|\psi\rangle$ using the zip-up algorithm, so that the bond dimension is kept in check. Do you have any ideas I could try to fix the `ConcretizationTypeError` I get when using zip-up?
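Why the bond dimension needs keeping in check: applying an MPO of bond dimension $D_k$ exactly to an MPS of bond dimension $\chi$ gives bond dimension up to $\chi D_k$, and adding MPSs adds their bond dimensions. A NumPy sketch of the exact sum via block-diagonal site tensors (`mps_add` and the other helpers are hypothetical names, not quimb functions):

```python
import numpy as np


def random_mps(L, d, chi, rng):
    # site tensors indexed (left bond, physical, right bond); edge bonds have size 1
    dims = [1] + [chi] * (L - 1) + [1]
    return [rng.normal(size=(dims[i], d, dims[i + 1])) for i in range(L)]


def mps_to_dense(mps):
    # contract the chain into a length d**L vector (small L only)
    v = mps[0]
    for A in mps[1:]:
        v = np.tensordot(v, A, axes=([-1], [0]))
    return v.reshape(-1)


def mps_add(a, b):
    """Exact MPS for |a> + |b> (assumes at least 2 sites).

    Site tensors are stacked block-diagonally, so internal bond
    dimensions add: chi_a + chi_b.
    """
    L = len(a)
    out = []
    for i, (A, B) in enumerate(zip(a, b)):
        if i == 0:
            # first site: concatenate along the right bond
            C = np.concatenate([A, B], axis=2)
        elif i == L - 1:
            # last site: concatenate along the left bond
            C = np.concatenate([A, B], axis=0)
        else:
            la, d, ra = A.shape
            lb, _, rb = B.shape
            C = np.zeros((la + lb, d, ra + rb), dtype=np.result_type(A, B))
            C[:la, :, :ra] = A
            C[la:, :, ra:] = B
        out.append(C)
    return out


rng = np.random.default_rng(2)
a, b = random_mps(5, 2, 3, rng), random_mps(5, 2, 4, rng)
s = mps_add(a, b)
print([T.shape for T in s])
```

Summing $N$ such terms exactly therefore gives a bond dimension of up to $\chi \sum_k D_k$, which is what a compression step (zip-up or otherwise) is meant to cut back down.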
> That being said, if the MPO has a really large bond dimension, maybe that's different enough from '1D and tree-like' to warrant some compression somewhere; you could run a `cotengra` `HyperCompressedOptimizer` to search for it. But I would stress that this stuff is not 'officially' supported in quimb yet.
I'll look into this `cotengra` `HyperCompressedOptimizer` approach. Thank you for telling me about it.
Thanks again for your replies – you've no idea how much time they save for me, and for that I'm incredibly grateful :)
> ... And I do think I also need to calculate the overlap term $\langle\psi|O^\dagger O|\psi\rangle$ when calculating the loss function $L$, unless I've missed something?
I see. Yes, I just meant that if you were only interested in finding $\min_{\phi} \| |\phi\rangle - O|\psi\rangle \|$, then the $\langle\psi|O^{\dagger} O|\psi\rangle$ term is constant and doesn't figure in the optimization, though you might still want it for the actual value of $L$.
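Spelling this out with $\phi$ as the variational state (a Wirtinger-derivative sketch, treating $\langle\phi|$ and $|\phi\rangle$ as independent variables):

$$
L(\phi) = \langle\psi|O^\dagger O|\psi\rangle + \langle\phi|\phi\rangle - \langle\psi|O^\dagger|\phi\rangle - \langle\phi|O|\psi\rangle,
\qquad
\frac{\partial L}{\partial \langle\phi|} = |\phi\rangle - O|\psi\rangle .
$$

The first term drops out of the gradient entirely. When $\psi$ is the variational state instead, that term depends on the parameters and has to be kept.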
> ... Do you have any ideas I could try to fix the `ConcretizationTypeError` I get when using zip-up?
Only that you need to find where your implementation of the algorithm calls `tensor_split` with `cutoff != 0.0` or `max_bond is None`; possibly there is a call to `MatrixProductState.compress` or some such function. Personally, I'd just set a breakpoint with the above condition to find it. You could also try torch, though the eventual performance might not be as good.
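One way to implement that conditional breakpoint is a small debugging wrapper. `guard_split` is a hypothetical helper, not part of quimb: it binds the call's arguments (defaults included) and raises as soon as a split is requested with a non-zero `cutoff` or with `max_bond=None`. It assumes the wrapped function takes `cutoff` and `max_bond` as named parameters.

```python
import functools
import inspect


def guard_split(fn):
    """Hypothetical debugging wrapper for a tensor-splitting function.

    Raises when fn is called with a non-zero ``cutoff`` or with
    ``max_bond=None``. Assumes both are named parameters of fn.
    """
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()  # so options the caller omitted show their defaults
        cutoff = bound.arguments.get("cutoff")
        max_bond = bound.arguments.get("max_bond")
        if cutoff != 0.0 or max_bond is None:
            # swap this raise for breakpoint() to inspect the caller interactively
            raise RuntimeError(
                f"trace-unsafe split: cutoff={cutoff!r}, max_bond={max_bond!r}"
            )
        return fn(*args, **kwargs)

    return wrapper


# demo on a dummy function with a similar-looking signature
def fake_split(T, left_inds, max_bond=None, cutoff=1e-10):
    return "split"


safe_split = guard_split(fake_split)
print(safe_split("T", ("a",), max_bond=8, cutoff=0.0))
```

To use it on real code you could try replacing the module-level `tensor_split` in `quimb.tensor.tensor_core` with the wrapped version before running the optimizer; whether every compression path in your quimb version goes through that name is an assumption worth checking. The traceback from the raise then points straight at the offending caller.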
Ah, I see the confusion regarding the loss function now: my variational function is $\psi$, not $\phi$! Sorry, I forgot to mention that.
So I always need to hard-set the bond dimension, or it might perform compression even with `cutoff=0.0`? That makes sense, actually.
Thanks a lot!
Hello,
Lately I've been trying to use JAX autodiff to minimise a loss function of mine, but I keep running into errors like “unhashable type: 'DeviceArray'” or “ConcretizationTypeError”. They happen when I try to use compression or `tensor.split()`, which makes me think that JAX has some sort of incompatibility with SVDs, since SVDs are typically used to compress bonds and split tensors.
Below is a reproduction of the issue, based on Chapter 9 of your user guide. With `chi = None` this code runs, but with `chi = 8`, when compression is performed, it crashes with the following error message:
I find this particularly strange given that the code in Chapter 4.8 works just fine for me, despite also using JAX autodiff along with compression (`max_bond` is limited to 32 in `compute_local_expectation()`).
Do you have any idea what might be wrong?
And thanks very much for this excellent numerical library – I LOVE quimb!
Cheers