PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[BUG] Incorrect output pytree when using qml.counts() in specific output patterns #1016

Open mehrdad2m opened 3 months ago

mehrdad2m commented 3 months ago

Context

When qml.counts() appears in the output of a quantum circuit compiled with qjit, the output pytree is modified so that the element corresponding to qml.counts is replaced with tree_structure(("keys", "counts")). However, this transformation is buggy: it works for simple cases but transforms more complex patterns incorrectly.

An example that works fine:

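import pennylane as qml
from catalyst import qjit
from jax.tree_util import tree_flatten

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}

result = circuit(0.5)
# Only the tree structure matters here, so the leaves are discarded.
_, result_tree = tree_flatten(result)
print(result_tree)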

The result is as expected:

PyTreeDef({'1': (*, *)})

The following two patterns, however, result in the wrong output pytree:

1.

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

results in:

PyTreeDef(((*, *), {'2': *}))

instead of the expected pytree of:

PyTreeDef(({'1': (*, *)}, {'2': *}))

2.

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

results in:

PyTreeDef(([{'1': *}, {'2': *}], (*, *)))

while the expected pytree is:

PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))

A possible solution is to update trace_quantum_measurements(), which is where the output pytree is modified. One option is to write a function replace_child_tree(tree, i, subtree) that receives a pytree and replaces the ith node of tree visited in a DFS with subtree.

josh146 commented 2 months ago

Thanks for catching this @mehrdad2m! How involved would you say the fix is -- is it straightforward, or would it require some exploration?

mehrdad2m commented 2 months ago

Hi @josh146, it is pretty straightforward. The fix should be made in trace_quantum_measurements, which is where the output pytree is modified. Basically, the simple version of the problem is to write a function replace_child_tree(tree, i, subtree) that receives a pytree and replaces the ith node of tree visited in a DFS with subtree. The only tricky part is working with pytrees :)
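
For illustration, here is a minimal sketch of what replace_child_tree could look like using only jax.tree_util. It is not the actual Catalyst fix. It assumes that tree and subtree are PyTreeDef objects and that "the ith node" means the ith leaf in flattening order (which is a depth-first, left-to-right traversal). The idea is to rebuild the tree with a nested placeholder at position i and then read the structure back.

```python
from jax.tree_util import tree_structure, tree_unflatten


def replace_child_tree(tree, i, subtree):
    """Return a PyTreeDef equal to `tree` with its i-th leaf replaced by `subtree`.

    Assumption: `tree` and `subtree` are PyTreeDef objects, and `i` indexes
    the leaves of `tree` in flattening (depth-first) order.
    """
    if not 0 <= i < tree.num_leaves:
        raise IndexError(f"leaf index {i} out of range for {tree.num_leaves} leaves")

    # Placeholder leaves; 0 is used because None would count as an empty node.
    leaves = [0] * tree.num_leaves
    # Put a container with the structure of `subtree` in the i-th slot.
    leaves[i] = tree_unflatten(subtree, [0] * subtree.num_leaves)
    # Rebuild the container and flatten it again to obtain the new structure.
    return tree_structure(tree_unflatten(tree, leaves))


# Structure of the first failing pattern before the counts leaf is expanded.
base = tree_structure(({"1": 0}, {"2": 0}))
counts_tree = tree_structure(("keys", "counts"))
fixed = replace_child_tree(base, 0, counts_tree)
print(fixed)
```

Note that replacing a leaf with a two-leaf subtree shifts the indices of all later leaves, so if several qml.counts() results appear in one output, a caller would probably want to apply the replacements from the last index to the first.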