CQCL / tket

Source code for the TKET quantum compiler, Python bindings and utilities
https://tket.quantinuum.com/
Apache License 2.0

`ClassicalExpBox` interacts badly with `flatten_registers` and `DecomposeBoxes` #1544

Open cqc-alec opened 4 weeks ago

cqc-alec commented 4 weeks ago

Apologies for the rather vague title. I have been investigating #1541 and have only a partial understanding of the root causes and still no idea how to solve it. The following Python snippet illustrates two of the problems, which are related (both involve calls to flatten_registers() messing things up):

```python
from pytket.circuit import Bit, CircBox, Circuit, OpType
from pytket.passes import DecomposeBoxes

x_bits = [Bit("x", i) for i in range(2)]
y_bit = Bit("y", 0)
all_bits = x_bits + [y_bit]

def make_base_circ():
    c = Circuit()
    for b in all_bits:
        c.add_bit(b)
    return c

def make_circ():
    c = make_base_circ()
    c.add_classicalexpbox_bit(x_bits[0] ^ x_bits[1], [y_bit])
    return c

def check_circ(c):
    cmds = c.get_commands()
    assert len(cmds) == 1
    op = cmds[0].op
    assert op.type == OpType.ClassicalExpBox
    exp = op.get_exp()
    args = exp.args
    assert args == x_bits

def problem_1():
    c0 = make_circ()
    check_circ(c0)  # OK
    c1 = c0.copy()
    c1.flatten_registers()
    check_circ(c0)  # fails

def problem_2():
    c0 = make_circ()
    check_circ(c0)  # OK
    c1 = make_base_circ()
    cbox = CircBox(c0)
    c1.add_circbox(cbox, all_bits)
    DecomposeBoxes().apply(c1)
    check_circ(c1)  # fails
```

In problem 1, the circuit c0, which should not be modified at all by calling flatten_registers() on a copy, has in fact been modified. This presumably has something to do with the convoluted way in which ClassicalExpBox is defined (templated on a Python class), but I don't understand it.

In problem 2, DecomposeBoxes() is called on a circuit containing a CircBox that itself contains a ClassicalExpBox. DecomposeBoxes() calls flatten_registers() on the replacement circuit, leading to a similar problem.
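The following is a minimal pure-Python model of problem 2, not tket code. A command is represented by two hypothetical lists: the bits it acts on and the bits named inside its expression. The model assumes that flatten_registers assigns c[i] in sorted order, and that DecomposeBoxes then maps the flattened bits back onto the box's arguments by rewriting only the command's arguments. Under those assumptions the expression is left referring to c[0] and c[1], which would be consistent with check_circ(c1) failing.

```python
from dataclasses import dataclass


# Hypothetical model, not the tket API: bits are (register, index) tuples.
@dataclass
class Command:
    args: list      # bits the command acts on (owned by the circuit)
    exp_args: list  # bits named inside the expression object


def flatten(cmd):
    """Rename every bit to ("c", i), assuming sorted order, as a stand-in for flatten_registers."""
    bits = sorted(set(cmd.args) | set(cmd.exp_args))
    rmap = {b: ("c", i) for i, b in enumerate(bits)}
    cmd.args = [rmap[b] for b in cmd.args]
    # In-place update of the expression's arguments, mimicking rename_units.
    cmd.exp_args[:] = [rmap[b] for b in cmd.exp_args]
    return rmap


def substitute(cmd, boundary):
    """Map the box's internal bits onto the outer circuit's bits.
    In this model only the command arguments are rewritten."""
    cmd.args = [boundary[b] for b in cmd.args]


x0, x1, y0 = ("x", 0), ("x", 1), ("y", 0)
cmd = Command(args=[x0, x1, y0], exp_args=[x0, x1])
rmap = flatten(cmd)
substitute(cmd, {flat: orig for orig, flat in rmap.items()})
print(cmd.args)      # [('x', 0), ('x', 1), ('y', 0)]
print(cmd.exp_args)  # [('c', 0), ('c', 1)]
```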

sjdilkes commented 1 week ago

When we copy circuits, we copy the Op_ptr of each vertex: https://github.com/CQCL/tket/blob/2936a367daed02faab6ba73a85c10e1cf0a60606/tket/src/Circuit/macro_manipulation.cpp#L78.

When copying a circuit containing a ClassicalExpBox, this means both circuits use the same Op_ptr for their corresponding vertices. That Op_ptr points to the op holding the templated classical expression object _expr.
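In Python terms, copying a circuit this way is a shallow copy. The sketch below is a toy model with hypothetical ExpBox and Circ classes (not pytket's): the copy gets its own list of vertices, but each vertex still refers to the same op object, much as both tket circuits end up holding the same Op_ptr.

```python
class ExpBox:
    """Hypothetical stand-in for a ClassicalExpBox holding a mutable expression."""

    def __init__(self, exp_args):
        self.exp_args = list(exp_args)


class Circ:
    """Hypothetical circuit: a list of vertices, each referring to an op."""

    def __init__(self, ops):
        self.ops = ops

    def copy(self):
        # New vertex list, but every entry still points at the same op object,
        # analogous to copying an Op_ptr (a shared pointer) when copying the graph.
        return Circ(list(self.ops))


c0 = Circ([ExpBox(["x[0]", "x[1]"])])
c1 = c0.copy()
print(c1.ops is c0.ops)        # False: the vertex lists are distinct
print(c1.ops[0] is c0.ops[0])  # True: the op itself is shared
```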

When we call flatten_registers (also a substep of DecomposeBoxes), the method has additional handling to make sure the rename_units method of ClassicalExpBoxBase is called (https://github.com/CQCL/tket/blob/2936a367daed02faab6ba73a85c10e1cf0a60606/tket/include/tket/Circuit/Circuit.hpp#L1778). This updates the arguments of the _expr object: https://github.com/CQCL/tket/blob/2936a367daed02faab6ba73a85c10e1cf0a60606/tket/include/tket/Circuit/ClassicalExpBox.hpp#L33. Although rename_units is marked const, it mutates (as the code acknowledges is possible) the _expr held by the op that the shared Op_ptr points to, so the change is seen by both circuits.
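The "const but still mutating" behaviour can be mimicked in Python with a frozen dataclass that holds a mutable list: the attribute cannot be rebound, yet the list it refers to can be changed in place. This is only an analogy, assuming the templated box reaches a mutable Python expression through a handle. ExpBox and its rename_units method here are hypothetical stand-ins, and the renaming map is an assumed example of what flatten_registers might produce for this issue's bits.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ExpBox:
    """Hypothetical stand-in for ClassicalExpBox. frozen=True forbids rebinding
    exp_args (loosely like a const member), but the list itself stays mutable."""

    exp_args: list

    def rename_units(self, rmap):
        # Mutates the referenced list in place, even though the box is "frozen".
        for i, arg in enumerate(self.exp_args):
            self.exp_args[i] = rmap.get(arg, arg)


box = ExpBox(["x[0]", "x[1]"])
op_in_c0 = box  # vertex in the original circuit
op_in_c1 = box  # vertex in the copy: same object, like a shared Op_ptr
op_in_c1.rename_units({"x[0]": "c[0]", "x[1]": "c[1]"})  # e.g. during flattening
print(op_in_c0.exp_args)  # ['c[0]', 'c[1]']
```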

This additional renaming is presumably done to help with later lowering: the ClassicalExpBox handling in pytket-phir looks at the arguments of the _expr object, not at the arguments of the Command it was recovered from.
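To illustrate why the expression's own arguments matter, here is a hypothetical lowering function (not pytket-phir's actual code) that, like the handling described above, reads bits from the expression rather than from the command. If rename_units were skipped, the expression would still name x[0] and x[1] after flattening, even though the command now acts on c[0], c[1] and c[2].

```python
def lower_xor(exp_args):
    """Hypothetical lowering: emits text from the expression's own arguments."""
    return " ^ ".join(exp_args)


cmd_args = ["c[0]", "c[1]", "c[2]"]  # command arguments after flattening
stale_exp_args = ["x[0]", "x[1]"]    # expression arguments if never renamed
renamed_exp_args = ["c[0]", "c[1]"]  # expression arguments after rename_units

for exp_args in (stale_exp_args, renamed_exp_args):
    consistent = set(exp_args) <= set(cmd_args)
    print(lower_xor(exp_args), "consistent" if consistent else "refers to missing bits")
# x[0] ^ x[1] refers to missing bits
# c[0] ^ c[1] consistent
```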

A quick workaround for the Python snippet in this issue is to change problem_1 to:

```python
def problem_1():
    c0 = make_circ()
    check_circ(c0)  # OK
    c1 = c0.copy()
    c1 = Circuit.from_dict(c1.to_dict())  # round trip creates fresh Ops
    c1.flatten_registers()
    check_circ(c0)  # now OK: c0 is unaffected
```

Writing to and reading back from the dictionary creates a new Op and Op_ptr for the ClassicalExpBox, so the mutation made by flatten_registers doesn't affect the original circuit.
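A toy model of why the round trip helps: rebuilding an object from its serialised form allocates fresh objects, so nothing is shared with the original. The ExpBox class and its to_dict/from_dict methods below are hypothetical stand-ins that only mirror the names of Circuit.to_dict and Circuit.from_dict.

```python
import json


class ExpBox:
    """Hypothetical stand-in for a ClassicalExpBox holding a mutable expression."""

    def __init__(self, exp_args):
        self.exp_args = list(exp_args)

    def to_dict(self):
        return {"type": "ClassicalExpBox", "exp_args": list(self.exp_args)}

    @staticmethod
    def from_dict(d):
        return ExpBox(d["exp_args"])


original = ExpBox(["x[0]", "x[1]"])
# Serialise to JSON text and rebuild: the result is a new object with its own list.
rebuilt = ExpBox.from_dict(json.loads(json.dumps(original.to_dict())))
rebuilt.exp_args[0] = "c[0]"  # e.g. a rename applied to the rebuilt copy
print(rebuilt is original)    # False
print(original.exp_args)      # ['x[0]', 'x[1]']
```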

The obvious fix is to update copy_graph to add extra handling for ClassicalExpBox. Unfortunately, this requires casting the Op_ptr, which is a bit of a nightmare. ClassicalExpBoxBase exists to provide a non-templated parent class with a rename_units method, but to do the cast we need the templated type. To instantiate the template properly we need a Python object type, which means referencing pybind11, as is done here: https://github.com/CQCL/tket/blob/2936a367daed02faab6ba73a85c10e1cf0a60606/pytket/binders/circuit/classical.cpp#L37.

So the only solution I can currently see is to link pybind11 into macro_manipulation.cpp, which would let us cast the Op_ptr appropriately.
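The intended effect of the copy_graph change can be sketched in Python, where isinstance plays the role of the C++ cast that needs the templated type. This is a model under that assumption, not tket code: Op, ExpBox and copy_ops are hypothetical. Ordinary ops are still shared between the original and the copy, while each copy gets its own expression box, so a later rename on the copy cannot reach the original.

```python
import copy


class Op:
    """Hypothetical op with no mutable state: safe to share between circuits."""

    def __init__(self, name):
        self.name = name


class ExpBox(Op):
    """Hypothetical op whose expression can be mutated by a rename."""

    def __init__(self, exp_args):
        super().__init__("ClassicalExpBox")
        self.exp_args = list(exp_args)


def copy_ops(ops):
    """Sketch of the proposed copy step: share ordinary ops, but deep-copy
    expression boxes. The isinstance check stands in for the C++ cast."""
    return [copy.deepcopy(op) if isinstance(op, ExpBox) else op for op in ops]


h = Op("H")
box = ExpBox(["x[0]", "x[1]"])
copied = copy_ops([h, box])
copied[1].exp_args[0] = "c[0]"  # rename applied to the copy only
print(copied[0] is h)    # True: plain ops are still shared
print(copied[1] is box)  # False: the expression box was cloned
print(box.exp_args)      # ['x[0]', 'x[1]']
```

Sharing the remaining ops keeps copying cheap; only the box type whose contents can be mutated pays for a deep copy.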

Do you have any other ideas?