PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Print function ordering #407

Open erick-xanadu opened 11 months ago

erick-xanadu commented 11 months ago

The print function should not be re-ordered. Example:

import pennylane as qml
from catalyst import debug, grad, measure, qjit

@qml.qnode(qml.device("lightning.qubit", wires=20))
def f(x):
    qml.RX(x, wires=[0])
    debug.print("hello")
    return qml.state()

@qjit(keep_intermediate=True, verbose=True)
def foo(x):
    y = f(x)
    debug.print("world")
    return y

print(foo(1.0))

This produces the following incorrect MLIR (excerpt):

module @foo {
  func.func public @jit_foo(%arg0: tensor<f64>) -> tensor<1048576xcomplex<f64>> attributes {llvm.emit_c_interface} {
    %0 = call @f(%arg0) : (tensor<f64>) -> tensor<1048576xcomplex<f64>>
    "catalyst.print"() {const_val = "world"} : () -> ()
    return %0 : tensor<1048576xcomplex<f64>>
  }
  func.func private @f(%arg0: tensor<f64>) -> tensor<1048576xcomplex<f64>> attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage<internal>, qnode} {
    "catalyst.print"() {const_val = "hello"} : () -> ()
    quantum.device["/home/erick.ochoalopez/Code/cataliist/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False}"]
    %0 = stablehlo.constant dense<20> : tensor<i64>

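Notice that the print statement is the first operation in the QNode function. It comes ahead of the device initialization and the RX gate, rather than following the gate as it does in the source.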

The correct MLIR would have the print statement after the RX operation.

dime10 commented 10 months ago

This was a concern I raised when we updated the split tracing procedure (quantum/classical). @grwlf thought he would be able to maintain the ordering; do you know what might be going wrong?

josh146 commented 10 months ago

Ah, this may actually be quite an important bug to address

dime10 commented 10 months ago


@josh146, why do you think so? I agree we should make every effort to print in the "right location", but once you have compiler transforms it's not always well-defined what that means. I also think it's less problematic than the null-termination issue, which could cause crashes.

josh146 commented 10 months ago

Oh, it depends on the details of this bug. If print statements can be reordered anywhere in the QNode (which is how I interpreted it), that would be quite an important bug, since the printed output may be incorrect (e.g., a variable printed before it is defined or after its value has changed).

However, if the print statement is only reordered around operations that don't affect what it prints, this is less urgent.

dime10 commented 10 months ago

Oh right, I'm pretty sure print statements that print program values are always ordered after the definition of that value; otherwise the IR would be invalid, since a value must be defined before it is used. The problem here mainly applies to print statements on strings or Python objects, whose order is not strictly defined in a dataflow-based representation.

Now you could argue that the print statement order should still be preserved even for strings, in order to perform the elaborate (😉) debugging procedure of placing print statements throughout the program to determine at which instruction the program is failing, but I'm not sure we want to guarantee this behaviour.
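To make the dataflow point concrete, here is a toy model in plain Python. It is not Catalyst's tracing or lowering code: the `Op` class and `dataflow_schedule` function are hypothetical, and the op names only loosely mirror the first example. The scheduler emits any op whose inputs are already defined, which is all a purely dataflow-based ordering guarantees.

```python
from dataclasses import dataclass


@dataclass
class Op:
    name: str
    inputs: tuple = ()   # values this op reads
    outputs: tuple = ()  # values this op defines


def dataflow_schedule(ops, args=()):
    """Emit, round by round, every op whose inputs are already defined.

    Only data dependencies constrain the order, so an op with no inputs
    (such as printing a constant string) is ready in the very first round.
    """
    defined = set(args)
    remaining = list(ops)
    order = []
    while remaining:
        ready = [op for op in remaining if set(op.inputs) <= defined]
        if not ready:
            raise ValueError("undefined input or cycle")
        for op in ready:
            order.append(op.name)
            defined.update(op.outputs)
            remaining.remove(op)
    return order


# Hypothetical source order: the first example plus a printed measurement.
program = [
    Op("device", outputs=("reg",)),
    Op("RX", inputs=("reg", "x"), outputs=("reg1",)),
    Op('print "hello"'),                              # constant string: no inputs
    Op("measure", inputs=("reg1",), outputs=("m",)),
    Op("print m", inputs=("m",)),                     # reads a program value
]

order = dataflow_schedule(program, args=("x",))
print(order)  # ['device', 'print "hello"', 'RX', 'measure', 'print m']
```

The constant-string print has no inputs, so it is ready immediately and lands ahead of `RX`, while `print m` reads a value and cannot move above `measure`.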

josh146 commented 10 months ago

Thanks @dime10! The context helps. This is definitely lower priority and potentially does not need to be fixed at all in some cases, since, when a gate and the surrounding classical processing do not depend on each other, the distinction of 'when' the gate is applied relative to that processing is quite an arbitrary one.

erick-xanadu commented 10 months ago

I agree that this is low priority, but I think it is important that the print statements are correctly ordered. Gates might raise exceptions at runtime, and printing is a side effect. In Python, it is clear that the print statement below should not run:

raise RuntimeError()

print("Hello")  # "Hello" is never printed.

In some other compiled languages, statements that might produce side effects are likewise not reordered; Java, for example, would not reorder the print statement in a case similar to ours.
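The exception argument can be illustrated with a small plain-Python sketch. All names are hypothetical; `cnot` simply raises to stand in for a gate that fails at runtime. Running the same two ops in source order and in the reordered form shows the observable difference.

```python
# Toy executor for the exception argument; all names are hypothetical.

def cnot(out):
    # Stand-in for a gate that fails at runtime.
    raise RuntimeError("gate failed")


def print_hello(out):
    out.append("hello")


ACTIONS = {"CNOT": cnot, 'print "hello"': print_hello}


def run(schedule):
    """Run ops in the given order and return what was printed.

    Execution stops at the first exception, as it would in the real program.
    """
    out = []
    try:
        for name in schedule:
            ACTIONS[name](out)
    except RuntimeError:
        pass  # the program aborts here; later ops never run
    return out


print(run(["CNOT", 'print "hello"']))  # source order -> []
print(run(['print "hello"', "CNOT"]))  # reordered    -> ['hello']
```

With the source order nothing is printed, but once the print is hoisted above the gate, "hello" appears even though the program fails.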

Our compiler currently reorders the print statement. Consider the following code:

import pennylane as qml
from catalyst import debug, grad, measure, qjit

@qml.qnode(qml.device("lightning.qubit", wires=2))
def f(x, y):
    qml.CNOT(wires=[x, y])  # Can raise a runtime exception
    debug.print("hello")
    return qml.state()

@qjit()
def foo(x : int, y : int):
    return f(x, y)

print(foo.mlir)

In the generated MLIR, the print statement comes before the quantum gate:

  func.func private @f(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<4xcomplex<f64>> attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage<internal>, qnode} {
    "catalyst.print"() {const_val = "hello\00"} : () -> ()  // <---- here

    // ...snip...
    %3:2 = quantum.custom "CNOT"() %1, %2 : !quantum.bit, !quantum.bit
    // ...snip...
    return %9 : tensor<4xcomplex<f64>>
  }

dime10 commented 4 months ago

This could be revisited after the recent changes we've made. I believe the original runtime print statement, which acts on program values, would be okay to reorder now, whereas the debug print via callback should preserve all side effects.

@erick-xanadu can you double check that this is still an issue with the callback print?
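One common way to pin down the relative order of side-effecting operations is to thread an ordering token through them, so that each effectful op depends on the previous one. The sketch below uses Python's standard `graphlib` module (Python 3.9+) on a hypothetical op list. It illustrates the technique under that assumption and does not describe how Catalyst's callback-based print is implemented. The gate is treated as effectful, following the argument that gates may raise at runtime.

```python
from graphlib import TopologicalSorter

# Hypothetical op list in source order: (name, data dependencies, has side effect).
# The gate is marked effectful because it may raise at runtime.
PROGRAM = [
    ("device", [], True),
    ("RX", ["device"], True),
    ('print "hello"', [], True),  # constant string: no data dependencies
    ("state", ["RX"], False),
]


def schedule(ops, use_tokens):
    """Topologically sort ops; optionally chain effectful ops with a token edge."""
    graph = TopologicalSorter()
    last_effect = None
    for name, deps, effectful in ops:
        preds = list(deps)
        if use_tokens and effectful and last_effect is not None:
            # The token edge: this op must wait for the previous effectful op.
            preds.append(last_effect)
        graph.add(name, *preds)
        if effectful:
            last_effect = name
    return list(graph.static_order())


without_tokens = schedule(PROGRAM, use_tokens=False)
with_tokens = schedule(PROGRAM, use_tokens=True)
print(without_tokens)  # print "hello" is free to run before RX
print(with_tokens)     # print "hello" is pinned after RX
```

Without the token edges, the print has no predecessors and can be scheduled before `RX`. With them, it inherits the source order of the other effectful ops, while pure ops such as `state` remain constrained only by their data.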