TeamGraphix / graphix

measurement-based quantum computing (MBQC) compiler and simulator
https://graphix.readthedocs.io
Apache License 2.0

[Bug]: Partial trace of state vector backend #73

Open masa10-f opened 1 year ago

masa10-f commented 1 year ago

Describe the bug

The partial trace of a state vector is generally expected to return a density matrix, but Statevec.ptrace currently returns a state vector. This is an incorrect operation for non-separable quantum states, such as a Bell state.
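As a minimal sketch of why this matters, the NumPy snippet below computes the reduced density matrix of a Bell state without using graphix. The helper name `partial_trace` and its signature are hypothetical and only serve the illustration; it assumes qubit 0 is the most significant axis and that `keep` is sorted in ascending order.

```python
import numpy as np


def partial_trace(psi, keep, n):
    """Reduced density matrix of an n-qubit pure state on the qubits in `keep`.

    Hypothetical helper for illustration; `keep` is assumed to be sorted.
    """
    psi = np.asarray(psi, dtype=np.complex128).reshape((2,) * n)
    traced = [q for q in range(n) if q not in keep]
    # Contract the traced-out axes of |psi> with the same axes of <psi|.
    rho = np.tensordot(psi, psi.conj(), axes=(traced, traced))
    dim = 2 ** len(keep)
    return rho.reshape(dim, dim)


# Bell state (|00> + |11>) / sqrt(2)
bell = np.array([1, 0, 0, 1]) / np.sqrt(2)
rho = partial_trace(bell, keep=[0], n=2)
purity = np.trace(rho @ rho).real
print(rho.real)
print("purity:", purity)
```

The reduced state is the maximally mixed state with purity 1/2, so no single-qubit state vector can represent it.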

To Reproduce

You can check the above behavior with the following code.

# %%
import numpy as np
from graphix.sim.statevec import Statevec

# %%
# prepare H and CNOT gates.
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0],
                 [0, 0, 0, 1], [0, 0, 1, 0]])

# %%
# make a bell state
sv = Statevec(plus_states=False, nqubit=2)
sv.evolve(H, [0])
sv.evolve(CNOT, [0, 1])

print(sv.flatten())
# %%
# trace out 2nd qubit
sv.ptrace([1])
print(sv.flatten())
# the result should be (|0><0| + |1><1|) / 2, but it returns |0> or |1>

Expected behavior

Statevec.ptrace should return a density matrix. Any code that converts a reduced density matrix into a state vector should check its purity before the conversion.
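One way such a purity check could look is sketched below in plain NumPy. The function name `density_matrix_to_statevector` and the tolerance `atol` are assumptions made for this example, not part of graphix.

```python
import numpy as np


def density_matrix_to_statevector(rho, atol=1e-9):
    """Convert a density matrix to a state vector, refusing mixed states.

    Hypothetical helper for illustration only.
    """
    rho = np.asarray(rho, dtype=np.complex128)
    # Purity Tr(rho^2) equals 1 only for pure states.
    purity = np.trace(rho @ rho).real
    if not np.isclose(purity, 1.0, atol=atol):
        raise ValueError(f"state is mixed (purity={purity:.6f})")
    # eigh returns eigenvalues in ascending order, so the last eigenvector
    # belongs to the eigenvalue 1 and equals the state up to a global phase.
    _, eigvecs = np.linalg.eigh(rho)
    return eigvecs[:, -1]


plus = np.array([1, 1]) / np.sqrt(2)
vec = density_matrix_to_statevector(np.outer(plus, plus.conj()))
print(abs(np.vdot(plus, vec)))
```

Passing the maximally mixed state `np.eye(2) / 2` raises `ValueError` instead of silently returning a vector.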

Environment (please complete the following information):

Additional context

N/A

shinich1 commented 1 year ago

Just to clarify: this indeed is something to be taken care of, but pattern.simulate_pattern(backend='statevector') is working perfectly fine as it is, because the measured qubits are separable from the rest.

masa10-f commented 1 year ago

I've written a draft improvement plan for partial tracing a separable state below.

# %%
import numpy as np
from graphix.sim.statevec import Statevec, meas_op

import time

def truncate_one_qubit(k, sv):
    n = len(sv.dims())
    taken = np.zeros(2**(n - 1), dtype=np.complex128)
    state = sv.flatten()
    for i in range(2**k):
        for j in range(2**(n - k - 1)):
            taken[i * 2**(n - k - 1) + j] = state[2**(n - k) * i + j]

    norm = taken.dot(taken.flatten().conjugate())
    taken = taken / norm**0.5
    return taken

# %%
# prepare a non-separable state
k = 3
n = 10
statevec = Statevec(nqubit=n)
for i in range(n):
    statevec.entangle((i, (i + 1) % n))
# print(statevec.flatten())

# %%
# measure qubit k
m_op = meas_op(np.pi / 5)
statevec.evolve(m_op, [k])
# print(statevec.flatten())

# %%
# discard qubit k
start = time.perf_counter()
reduced = truncate_one_qubit(k, statevec)
end = time.perf_counter()
print("time(new method)", end - start)
# print(reduced)

# %%
# reference
start = time.perf_counter()
statevec.ptrace([k])
end = time.perf_counter()
print("time(ptrace)", end - start)
# print(statevec.flatten())

# %%
# check inner product
inner_product = reduced.dot(statevec.flatten().conjugate())
print(np.abs(inner_product))

In my environment, the new method is about 3 orders of magnitude faster than ptrace when the number of qubits (nqubit) is 10.
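The index arithmetic in the double loop picks the amplitudes whose k-th bit is 0. As a rough sketch, assuming qubit 0 is the most significant axis of the flattened state, the same selection can be written as a reshape followed by `take`. The function names here are hypothetical and the snippet uses only NumPy.

```python
import numpy as np


def take_zero_branch_loop(state, k, n):
    """Copy the amplitudes with qubit k equal to 0, using explicit loops."""
    taken = np.zeros(2 ** (n - 1), dtype=np.complex128)
    for i in range(2 ** k):
        for j in range(2 ** (n - k - 1)):
            taken[i * 2 ** (n - k - 1) + j] = state[i * 2 ** (n - k) + j]
    return taken


def take_zero_branch_vectorized(state, k, n):
    """Same selection: view the state as an n-axis tensor and slice axis k."""
    return state.reshape((2,) * n).take(indices=0, axis=k).reshape(-1)


rng = np.random.default_rng(0)
n, k = 5, 2
state = rng.normal(size=2 ** n) + 1j * rng.normal(size=2 ** n)
same = np.allclose(
    take_zero_branch_loop(state, k, n),
    take_zero_branch_vectorized(state, k, n),
)
print(same)
```

Neither version normalizes the result; that step is unchanged from the draft above.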

masa10-f commented 1 year ago

The above method is not yet complete: it only extracts the |0> components of the k-th qubit, so it fails when that qubit is in the |1> state, and it returns a plain array rather than a Statevec. However, there are several ways to resolve this problem. Once this is completed, I'd like to see a performance comparison with the previous method.

nabe98 commented 12 months ago

This is completed by the following.

def truncate_one_qubit(self, qarg):
    """truncate one qubit

    Args:
        qarg (int): qubit index
    """
    # extract |***0_{qarg}***> components if not zero else |***1_{qarg}***>
    psi = self.psi.take(indices=0, axis=qarg)
    self.psi = psi if psi[(0,) * psi.ndim] != 0.0 else self.psi.take(indices=1, axis=qarg)
    self.normalize()
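Note that the check above inspects only the first amplitude of the |0> branch, so it may pick the wrong branch when that single amplitude happens to be zero. A possible variant, sketched here on a bare NumPy array rather than on Statevec, decides by the norm of the whole branch. The function name `truncate_one_qubit_by_norm` and the tolerance `atol` are hypothetical.

```python
import numpy as np


def truncate_one_qubit_by_norm(psi, qarg, atol=1e-12):
    """Drop axis `qarg` from a state tensor of shape (2,) * n.

    Assumes qubit `qarg` is separable from the rest (for example, it has
    just been measured). Hypothetical helper for illustration only.
    """
    branch0 = psi.take(indices=0, axis=qarg)
    norm0 = np.linalg.norm(branch0)  # 2-norm over all amplitudes
    if norm0 > atol:
        return branch0 / norm0
    branch1 = psi.take(indices=1, axis=qarg)
    return branch1 / np.linalg.norm(branch1)


# |1>_0 |0>_1: the |0> branch of qubit 1 is [0, 1], whose first entry is 0.
psi = np.zeros((2, 2), dtype=np.complex128)
psi[1, 0] = 1.0
reduced = truncate_one_qubit_by_norm(psi, qarg=1)
print(reduced)
```

The norm costs one extra pass over the branch, so whether it is worth it depends on how often a zero leading amplitude can occur in practice.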

The performance can be compared as follows.

# %%
import numpy as np
from graphix.sim.statevec import Statevec, meas_op

import time

def truncate_one_qubit_old(k, sv):
    n = len(sv.dims())
    taken = np.zeros(2**(n - 1), dtype=np.complex128)
    state = sv.flatten()
    for i in range(2**k):
        for j in range(2**(n - k - 1)):
            taken[i * 2**(n - k - 1) + j] = state[i * 2**(n - k) + j]

    norm = taken.dot(taken.flatten().conjugate())
    taken = taken / norm**0.5
    return taken

def truncate_one_qubit(k, sv):
    # extract |***0_{qarg}***> components if not zero else |***1_{qarg}***>
    psi = sv.psi.take(indices=0, axis=k)
    psi = psi if psi[(0,) * psi.ndim] != 0.0 else sv.psi.take(indices=1, axis=k)
    norm = psi.flatten().dot(psi.flatten().conjugate())
    psi = psi / norm**0.5
    return psi.flatten()

# %%
# prepare a non-separable state
k = 3
n = 10
statevec = Statevec(nqubit=n)
for i in range(n):
    statevec.entangle((i, (i + 1) % n))
# print(statevec.flatten())

# %%
# measure qubit k
m_op = meas_op(np.pi / 5)
statevec.evolve(m_op, [k])
# print(statevec.flatten())

# %%
# discard qubit k (old)
start = time.perf_counter()
reduced = truncate_one_qubit_old(k, statevec)
end = time.perf_counter()
print("time(old method)", end - start)
# print(reduced)

# %%
# discard qubit k (new)

start = time.perf_counter()
reduced2 = truncate_one_qubit(k, statevec)
end = time.perf_counter()
print("time(new method)", end - start)
# print(reduced2)

# %%
# reference
start = time.perf_counter()
statevec.ptrace([k])
end = time.perf_counter()
print("time(ptrace)", end - start)
# print(statevec.flatten())

# %%
# check inner product
inner_product = reduced.dot(statevec.flatten().conjugate())
print(np.abs(inner_product))
inner_product = reduced2.dot(statevec.flatten().conjugate())
print(np.abs(inner_product))

In my environment, the new version is sometimes slower than the old truncate_one_qubit in this sample. But when I changed the implementation in statevec to this new truncate_one_qubit, it was nearly 2 times faster than the old one.

shinich1 commented 12 months ago

@nabe98 thanks - could you paste a plot of speed comparison, for visual inspection? for example, could you compare pattern simulation speed for varying pattern size with and without the new code?

also what do you mean by below? is it sometimes slower somehow?

the new version is sometimes slower than the old truncate_one_qubit in this sample

masa10-f commented 12 months ago

@nabe98 Thank you! @shinich1 I compared the performance of the three methods. Even taking the standard deviations into account, @nabe98's method is faster than mine.

[Figure: comparison_truncation] [Figure: comparison_truncation_old_new]

# %%
import numpy as np
import matplotlib.pyplot as plt
from statistics import stdev, mean
from graphix.sim.statevec import Statevec, meas_op

import time
from copy import deepcopy

def truncate_one_qubit_old(k, sv):
    n = len(sv.dims())
    taken = np.zeros(2**(n - 1), dtype=np.complex128)
    state = sv.flatten()
    for i in range(2**k):
        for j in range(2**(n - k - 1)):
            taken[i * 2**(n - k - 1) + j] = state[i * 2**(n - k) + j]

    norm = taken.dot(taken.flatten().conjugate())
    taken = taken / norm**0.5
    return taken

def truncate_one_qubit(k, sv):
    # extract |***0_{qarg}***> components if not zero else |***1_{qarg}***>
    psi = sv.psi.take(indices=0, axis=k)
    psi = psi if psi[(0,) * psi.ndim] != 0.0 else sv.psi.take(indices=1, axis=k)
    norm = psi.flatten().dot(psi.flatten().conjugate())
    psi = psi / norm**0.5
    return psi.flatten()

# %%
time_old = []
time_new = []
time_ptrace = []
iteration = 30

# prepare a non-separable state
k = 3
n = 10
statevec = Statevec(nqubit=n)
for i in range(n):
    statevec.entangle((i, (i + 1) % n))
# print(statevec.flatten())

# %%
# measure qubit k
m_op = meas_op(np.pi / 5)
statevec.evolve(m_op, [k])
# print(statevec.flatten())

# %%
# discard qubit k (old)
for i in range(iteration):
    start = time.perf_counter()
    reduced = truncate_one_qubit_old(k, statevec)
    end = time.perf_counter()
    time_old.append(end - start)
    # print("time(old method)", end - start)
    # print(reduced)

# %%
# discard qubit k (new)
for i in range(iteration):
    start = time.perf_counter()
    reduced2 = truncate_one_qubit(k, statevec)
    end = time.perf_counter()
    time_new.append(end - start)
    # print("time(new method)", end - start)
    # print(reduced2)

# %%
# reference
for i in range(iteration):
    statevec_cp = deepcopy(statevec)
    start = time.perf_counter()
    statevec_cp.ptrace([k])
    end = time.perf_counter()
    time_ptrace.append(end - start)
    # print("time(ptrace)", end - start)
    # print(statevec_cp.flatten())

# %%
# check inner product
inner_product = reduced.dot(statevec_cp.flatten().conjugate())
print(np.abs(inner_product))
inner_product = reduced2.dot(statevec_cp.flatten().conjugate())
print(np.abs(inner_product))
# %%
# acquire statistics

mean_old = mean(time_old)
mean_new = mean(time_new)
mean_ptrace = mean(time_ptrace)

std_old = stdev(time_old)
std_new = stdev(time_new)
std_ptrace = stdev(time_ptrace)

# %%
# plot

plt.bar(["old", "new", "ptrace"], [mean_old, mean_new, mean_ptrace],
        yerr=[std_old, std_new, std_ptrace])
plt.ylabel("time (s)")
plt.yscale("log")
plt.show()
# %%

# plot without ptrace
plt.bar(["old", "new"], [mean_old, mean_new], yerr=[std_old, std_new])
plt.ylabel("time (s)")
plt.show()

masa10-f commented 12 months ago

To make a comparison in the pattern simulator, we need to modify the 'measure' method which is tailored for the ptrace method, tracing out a group of measured qubits together. With the new truncation method, we can individually trace out measured qubits. This will improve the performance.
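A rough sketch of tracing out measured qubits one at a time is given below, on a bare NumPy tensor rather than inside the simulator. It assumes every qubit in `qargs` is already separable from the rest, and it removes the highest axis first so the remaining indices stay valid. The function name `truncate_qubits` and the tolerance `atol` are hypothetical.

```python
import numpy as np


def truncate_qubits(psi, qargs, atol=1e-12):
    """Remove several separable qubits from a state tensor of shape (2,) * n.

    Hypothetical helper for illustration only.
    """
    # Go from the highest index down: removing axis q shifts only axes > q.
    for q in sorted(qargs, reverse=True):
        branch = psi.take(indices=0, axis=q)
        if np.linalg.norm(branch) <= atol:
            # The qubit was in |1>, so the |0> branch is empty.
            branch = psi.take(indices=1, axis=q)
        psi = branch / np.linalg.norm(branch)
    return psi


# Product state a (x) |1> (x) |0> on three qubits.
a = np.array([0.6, 0.8], dtype=np.complex128)
psi = np.kron(a, np.kron([0, 1], [1, 0])).reshape(2, 2, 2)
remaining = truncate_qubits(psi, qargs=[1, 2])
print(remaining)
```

In a simulator loop the same idea would apply right after each measurement, which avoids carrying measured qubits along until a grouped trace.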