jcmgray / quimb

A python library for quantum information and many-body calculations including tensor networks.
http://quimb.readthedocs.io

Converting between Tensor Network and MatrixProductState #168

Closed sebulo closed 1 year ago

sebulo commented 1 year ago

What is your issue?

Hi, I have noticed that the MatrixProductState (MPS) method from_TN() appears to change the underlying structure of the TensorNetwork (TN), and I cannot work out how the structure is changed. Here is an example of my problem:

If I now contract both the TN and the MPS and reshape them to the original image shape, I get two different results. The TN reconstructs the image correctly, while the MPS gives a result that looks like a permuted version of the original. However, the cores of the TN and the MPS are identical. I tried looking at get_equation() to see if the contraction schemes differ, but this is not the case.

Can anyone help me explain what is being changed when calling from_TN()?

I have the code to reproduce the example I just explained in a jupyter notebook.

!pip install --no-deps -U git+https://github.com/jcmgray/quimb.git@develop
!pip install torch
import quimb.tensor as qtn
import torch
import matplotlib.pyplot as plt
import numpy as np
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline
input_dim = 64
image = np.arange(1, input_dim**2 + 1).reshape(input_dim, input_dim)
image = torch.from_numpy(image).float()

# 64*64 = 4096 = 4**6 entries, so use 6 sites of physical dimension 4
dim_tn = int(np.log2(input_dim)) * [4]

# get MPS
mps = qtn.MatrixProductState.from_dense(image, dims=dim_tn)
mps.draw()
# get same Tensor network
tn = qtn.TensorNetwork(mps.tensors)
tn.draw()
# convert back to MPS
mps = qtn.MatrixProductState.from_TN(tn, L=len(dim_tn), cyclic=False, site_tag_id='I{}', site_ind_id='k{}')
mps.draw()

# Check cores of the TN and MPS are identical
for i in range(len(mps.tensors)):
    assert (tn.tensors[i].data-mps.tensors[i].data).sum().item() == 0.

# Get image representation from TN
tn = tn.contract()
tn_data = tn.data.reshape(input_dim,input_dim)
# Get image representation from MPS
mps = mps.contract()
mps_data = mps.data.reshape(input_dim,input_dim)

# Plot outputs
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.imshow(image, cmap="gray")
plt.title("Ground truth")
plt.subplot(1, 3, 2)
plt.imshow(tn_data, cmap="gray")
plt.title("Recon image TN")
plt.subplot(1, 3, 3)
plt.imshow(mps_data, cmap="gray")
plt.title("Recon image MPS")

Output: (attached image 1)

Upon further inspection, if I change the installed version of quimb from the development version to a stable release with "pip install quimb", I get the following result instead: (attached image 2)

If I use get_equation() on the MPS, I can see that the development version gives the following result:

If I reverse permute the output after contracting:

# Get image representation from TN
tn = tn.contract()
tn_data = tn.data
tn_data = tn_data.permute(5,4,3,2,1,0) # INSERTED reverse permutation
tn_data = tn_data.reshape(input_dim,input_dim)
# Get image representation from MPS
mps = mps.contract()
mps_data = mps.data
mps_data = mps_data.permute(5,4,3,2,1,0) # INSERTED reverse permutation
mps_data = mps_data.reshape(input_dim,input_dim)

Then I get this result: (attached image 3)

Can anyone help me explain what the difference is between the development and the stable version of quimb?

jcmgray commented 1 year ago

Hi @Sloeschcke, I suspect the issue is just that when you go between the TN representation and a dense form, you need to explicitly give the ordering of the indices. Once something is a tensor network, there is no definitive ordering of the indices or tensors - you need to specify it.

tn.contract() by default chooses the output indices in the order in which they first appear on the tensors (which is apparently different for tn and mps). This doesn't matter so much if you keep the result as a tensor/network, since everything is labelled, but when you call .data you are forgetting all the labels. The answer is hopefully:

  1. specify contract(output_inds=[...]), to ensure a particular ordering of the indices.
  2. use to_dense(['k0', 'k1', 'k2', 'k3'], ['k4', 'k5', 'k6', 'k7']) (if that is your particular encoding), which handles both the contraction and reshaping, and is for exactly this kind of purpose.
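The underlying point can be illustrated with plain numpy, independent of quimb: if you flatten contracted data without pinning down the output index order first, you get a permuted image. This is only a minimal sketch of the index-ordering issue (a 4x4 toy image split into four size-2 indices), not quimb's actual contraction path:

```python
import numpy as np

# A 4x4 "image" encoded as a 4-index tensor of shape (2, 2, 2, 2):
# axes (a, b) index the rows, (c, d) index the columns.
image = np.arange(16).reshape(4, 4)
t = image.reshape(2, 2, 2, 2)

# Flattening with the original index order recovers the image...
same = np.einsum('abcd->abcd', t).reshape(4, 4)
assert np.array_equal(same, image)

# ...but flattening with any other output order gives a permuted image,
# even though the tensor data itself is identical.
perm = np.einsum('abcd->dcba', t).reshape(4, 4)
assert not np.array_equal(perm, image)

# Fix: request the desired index order explicitly before flattening,
# which is what contract(output_inds=[...]) / to_dense([...]) do in quimb.
fixed = np.einsum('dcba->abcd', perm.reshape(2, 2, 2, 2)).reshape(4, 4)
assert np.array_equal(fixed, image)
```

The same data reshapes to different images purely depending on the axis order at the moment of flattening, which is why pinning the order via output_inds (or using to_dense with explicit index groups) resolves the discrepancy.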

Let me know if either of those work for you.

[As a side-note, the ordering of the indices and tensors is still deterministic (will be the same each run), to allow things like compiling computational graphs, but is not guaranteed from version to version.]

sebulo commented 1 year ago

Thanks @jcmgray! Both solutions 1 and 2 fixed my problem.