PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[BUG] Incorrect output pytree when using qml.counts() in specific output patterns #1016

Open mehrdad2m opened 4 weeks ago

mehrdad2m commented 4 weeks ago

Context

When qml.counts() is used in the output of a quantum circuit compiled with qjit, the output pytree is modified: the element corresponding to qml.counts is replaced with tree_structure(("keys", "counts")). However, this transformation is buggy. It works in simple cases but produces the wrong structure for more complex output patterns.
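To make the intended substitution concrete, here is a small sketch using jax.tree_util, the same pytree utilities the examples below use via tree_flatten. It only illustrates the effect on plain Python placeholders. The integer 0 standing in for a measurement result is an assumption for illustration and is not what Catalyst actually traces.

```python
from jax.tree_util import tree_structure

# Structure used for one qml.counts() result: a (keys, counts) pair, i.e. two leaves.
counts_tree = tree_structure(("keys", "counts"))
print(counts_tree)  # PyTreeDef((*, *))

# Before the substitution the counts measurement occupies a single leaf;
# after it, that leaf has been expanded into the (keys, counts) pair.
before = tree_structure({"1": 0})
after = tree_structure({"1": ("keys", "counts")})
print(before)  # PyTreeDef({'1': *})
print(after)   # PyTreeDef({'1': (*, *)})
```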

An example that works fine:

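import pennylane as qml
from catalyst import qjit
from jax.tree_util import tree_flatten

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    # counts() returns a (keys, counts) pair, hence the (*, *) in the pytree
    return {"1": qml.counts()}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)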

The result is as expected:

PyTreeDef({'1': (*, *)})

The following two patterns produce an incorrect output pytree:

1.

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

results in:

PyTreeDef(((*, *), {'2': *}))

The expected pytree is:

PyTreeDef(({'1': (*, *)}, {'2': *}))

2.

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

results in:

PyTreeDef(([{'1': *}, {'2': *}], (*, *)))

The expected pytree is:

PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))

josh146 commented 3 weeks ago

Thanks for catching this @mehrdad2m! How involved would you say the fix is -- is it straightforward, or would it require some exploration?

mehrdad2m commented 3 weeks ago


Hi @josh146, it is pretty straightforward. The fix belongs in trace_quantum_measurements, which is where the output pytree is modified. Reduced to its simplest form, the problem is to write a function replace_child_tree(tree, i, subtree) that receives a pytree and replaces the i-th node visited in a depth-first search (DFS) of tree with subtree. The only tricky part is working with pytrees :)
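A minimal sketch of such a helper, written against jax.tree_util rather than Catalyst internals. It is hypothetical (not Catalyst's API or the actual fix) and assumes that "the i-th node visited in a DFS" means the i-th leaf in JAX's flattening order, which is a depth-first, left-to-right traversal with dict keys sorted. The usage lines rebuild the two failing patterns from the issue, using 0 as a placeholder leaf for each measurement.

```python
from jax.tree_util import tree_map, tree_structure, tree_unflatten


def replace_child_tree(tree, i, subtree):
    """Hypothetical helper (not part of Catalyst): return a copy of the
    PyTreeDef `tree` in which its i-th leaf, counted in JAX's depth-first
    flattening order, is replaced by the PyTreeDef `subtree`."""
    if not 0 <= i < tree.num_leaves:
        raise IndexError(f"leaf index {i} out of range for {tree}")
    # Materialize `tree` as a Python object whose leaves are their own
    # positions 0, 1, 2, ... in flattening order.
    numbered = tree_unflatten(tree, list(range(tree.num_leaves)))
    # Materialize `subtree` too; 0 is the placeholder leaf because JAX treats
    # None as an empty node rather than a leaf.
    replacement = tree_unflatten(subtree, [0] * subtree.num_leaves)
    # Swap the i-th leaf for the replacement object and keep the others.
    swapped = tree_map(lambda leaf: replacement if leaf == i else leaf, numbered)
    # Flattening again yields a structure with `subtree` nested at position i.
    return tree_structure(swapped)


counts_tree = tree_structure(("keys", "counts"))

# Pattern 1: the counts result is leaf 0 and the expval is leaf 1.
fixed_1 = replace_child_tree(tree_structure(({"1": 0}, {"2": 0})), 0, counts_tree)
print(fixed_1)  # PyTreeDef(({'1': (*, *)}, {'2': *}))

# Pattern 2: leaves are expval "1" (0), counts "2" (1), expval "3" (2).
fixed_2 = replace_child_tree(
    tree_structure(([{"1": 0}, {"2": 0}], {"3": 0})), 1, counts_tree
)
print(fixed_2)  # PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))
```

Going through a concrete object and back with tree_unflatten and tree_structure avoids manipulating PyTreeDef internals directly, at the cost of building a throwaway placeholder object.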