PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0
122 stars 27 forks source link

Reusing the same `device.so` across multiple functions that call `quantum.set_state` causes a crash #1044

Closed paul0403 closed 1 week ago

paul0403 commented 3 weeks ago

The following MLIR has multiple functions that call `quantum.set_state` on the same device shared-object file, and running it crashes:

// Reproducer for this issue: the three @circuit_pytest_tape_* functions below each load
// the same device shared library (librtd_lightning.so) and each call quantum.set_state.
// Running this module crashes with "malloc(): unaligned tcache chunk detected".
module @circuit_pytest {
  // Entry point: builds a constant 4-element complex state vector and forwards it,
  // together with all three arguments, to @circuit_pytest.
  func.func public @jit_circuit_pytest(%arg0: tensor<f64>, %arg1: tensor<3xf64>, %arg2: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    // Amplitude (1, 0) at index 0 and zero elsewhere (presumably the |00> basis state).
    %cst = arith.constant dense<[(1.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00)]> : tensor<4xcomplex<f64>>
    %0 = call @circuit_pytest(%cst, %arg0, %arg1, %arg2) : (tensor<4xcomplex<f64>>, tensor<f64>, tensor<3xf64>, tensor<f64>) -> tensor<f64>
    return %0 : tensor<f64>
  }

  // Tape 2: set_state on both qubits from %arg0, RX(%arg1[2]) on qubit 1, and
  // expval of PauliZ on qubit 0 (computed after the RX in this tape).
  func.func @circuit_pytest_tape_2(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<3xf64>) -> tensor<f64> {
    // Same library path and device name as the other two tapes. Per the issue text, pointing
    // each tape at a separate copy of this .so avoids the crash.
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    // The operation implicated in the crash: removing every set_state makes it disappear.
    %3:2 = quantum.set_state(%arg0) %1, %2 : (tensor<4xcomplex<f64>>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit)
    // Pull element 2 of %arg1 out as a scalar f64 for the RX angle.
    %extracted_slice = tensor.extract_slice %arg1[2] [1] [1] : tensor<3xf64> to tensor<1xf64>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xf64> into tensor<f64>
    %extracted = tensor.extract %collapsed[] : tensor<f64>
    %out_qubits = quantum.custom "RX"(%extracted) %3#1 : !quantum.bit
    %4 = quantum.namedobs %3#0[ PauliZ] : !quantum.obs
    %5 = quantum.expval %4 : f64
    %from_elements = tensor.from_elements %5 : tensor<f64>
    // Reinsert both qubits, deallocate the register, then release the device.
    %6 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit
    %7 = quantum.insert %6[ 1], %out_qubits : !quantum.reg, !quantum.bit
    quantum.dealloc %7 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
  }

  // Tape 1: same as tape 2 but uses %arg1[1] as the RX angle, and measures qubit 0
  // (and reinserts it) before applying RX to qubit 1.
  func.func @circuit_pytest_tape_1(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<3xf64>) -> tensor<f64> {
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    %3:2 = quantum.set_state(%arg0) %1, %2 : (tensor<4xcomplex<f64>>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit)
    %4 = quantum.namedobs %3#0[ PauliZ] : !quantum.obs
    %5 = quantum.expval %4 : f64
    %from_elements = tensor.from_elements %5 : tensor<f64>
    %6 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit
    %extracted_slice = tensor.extract_slice %arg1[1] [1] [1] : tensor<3xf64> to tensor<1xf64>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xf64> into tensor<f64>
    %extracted = tensor.extract %collapsed[] : tensor<f64>
    %out_qubits = quantum.custom "RX"(%extracted) %3#1 : !quantum.bit
    %7 = quantum.insert %6[ 1], %out_qubits : !quantum.reg, !quantum.bit
    quantum.dealloc %7 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
  }

  // Tape 0: identical to tape 1 except that the RX angle is %arg1[0].
  func.func @circuit_pytest_tape_0(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<3xf64>) -> tensor<f64> {
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    %3:2 = quantum.set_state(%arg0) %1, %2 : (tensor<4xcomplex<f64>>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit)
    %4 = quantum.namedobs %3#0[ PauliZ] : !quantum.obs
    %5 = quantum.expval %4 : f64
    %from_elements = tensor.from_elements %5 : tensor<f64>
    %6 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit
    %extracted_slice = tensor.extract_slice %arg1[0] [1] [1] : tensor<3xf64> to tensor<1xf64>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xf64> into tensor<f64>
    %extracted = tensor.extract %collapsed[] : tensor<f64>
    %out_qubits = quantum.custom "RX"(%extracted) %3#1 : !quantum.bit
    %7 = quantum.insert %6[ 1], %out_qubits : !quantum.reg, !quantum.bit
    quantum.dealloc %7 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
  }

  // Calls tapes 0, 1 and 2 in sequence; each call loads the device library again.
  // Only tape 0's result (%0) is returned. %1, %2 and the arguments %arg1 and %arg3
  // are unused, but the calls still execute.
  func.func private @circuit_pytest(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<f64>, %arg2: tensor<3xf64>, %arg3: tensor<f64>) -> tensor<f64> attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage<internal>, qnode} {
    %0 = call @circuit_pytest_tape_0(%arg0, %arg2) : (tensor<4xcomplex<f64>>, tensor<3xf64>) -> tensor<f64>
    %1 = call @circuit_pytest_tape_1(%arg0, %arg2) : (tensor<4xcomplex<f64>>, tensor<3xf64>) -> tensor<f64>
    %2 = call @circuit_pytest_tape_2(%arg0, %arg2) : (tensor<4xcomplex<f64>>, tensor<3xf64>) -> tensor<f64>
    return %0 : tensor<f64>
  }

  // Runtime hooks wrapping quantum.init / quantum.finalize; presumably invoked around
  // @jit_circuit_pytest by the generated harness (confirm).
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}
malloc(): unaligned tcache chunk detected

However, doing either of the following avoids the crash:

  1. Create copies of the device library (`cp librtd_lightning.so librtd_lightning_1.so && cp librtd_lightning.so librtd_lightning_2.so`) and use these copies in the `quantum.device` ops, so that each function loads a different file.
  2. Remove all the `quantum.set_state` operations, i.e.:

    // Non-crashing variant of the module above: every quantum.set_state is removed (the
    // original lines are kept commented out). Uses of its results %3#0 / %3#1 are
    // replaced by the extracted qubits %1 / %2. All three tapes still load the same .so.
    module @circuit_pytest {
    // Entry point: builds the constant state vector and forwards it to @circuit_pytest
    // (the vector is now unused by the tapes because set_state was removed).
    func.func public @jit_circuit_pytest(%arg0: tensor<f64>, %arg1: tensor<3xf64>, %arg2: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    %cst = arith.constant dense<[(1.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00)]> : tensor<4xcomplex<f64>>
    %0 = call @circuit_pytest(%cst, %arg0, %arg1, %arg2) : (tensor<4xcomplex<f64>>, tensor<f64>, tensor<3xf64>, tensor<f64>) -> tensor<f64>
    return %0 : tensor<f64>
    }
    
    // Tape 2 without set_state: RX(%arg1[2]) on qubit 1, expval of PauliZ on qubit 0.
    func.func @circuit_pytest_tape_2(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<3xf64>) -> tensor<f64> {
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    //%3:2 = quantum.set_state(%arg0) %1, %2 : (tensor<4xcomplex<f64>>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit)
    %extracted_slice = tensor.extract_slice %arg1[2] [1] [1] : tensor<3xf64> to tensor<1xf64>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xf64> into tensor<f64>
    %extracted = tensor.extract %collapsed[] : tensor<f64>
    //%out_qubits = quantum.custom "RX"(%extracted) %3#1 : !quantum.bit
    %out_qubits = quantum.custom "RX"(%extracted) %2 : !quantum.bit
    //%4 = quantum.namedobs %3#0[ PauliZ] : !quantum.obs
    %4 = quantum.namedobs %1[ PauliZ] : !quantum.obs
    %5 = quantum.expval %4 : f64
    %from_elements = tensor.from_elements %5 : tensor<f64>
    //%6 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit
    %6 = quantum.insert %0[ 0], %1 : !quantum.reg, !quantum.bit
    %7 = quantum.insert %6[ 1], %out_qubits : !quantum.reg, !quantum.bit
    quantum.dealloc %7 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
    }
    
    // Tape 1 without set_state: measures qubit 0, then RX(%arg1[1]) on qubit 1.
    func.func @circuit_pytest_tape_1(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<3xf64>) -> tensor<f64> {
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    //%3:2 = quantum.set_state(%arg0) %1, %2 : (tensor<4xcomplex<f64>>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit)
    //%4 = quantum.namedobs %3#0[ PauliZ] : !quantum.obs
    %4 = quantum.namedobs %1[ PauliZ] : !quantum.obs
    %5 = quantum.expval %4 : f64
    %from_elements = tensor.from_elements %5 : tensor<f64>
    //%6 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit
    %6 = quantum.insert %0[ 0], %1 : !quantum.reg, !quantum.bit
    %extracted_slice = tensor.extract_slice %arg1[1] [1] [1] : tensor<3xf64> to tensor<1xf64>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xf64> into tensor<f64>
    %extracted = tensor.extract %collapsed[] : tensor<f64>
    //%out_qubits = quantum.custom "RX"(%extracted) %3#1 : !quantum.bit
    %out_qubits = quantum.custom "RX"(%extracted) %2 : !quantum.bit
    %7 = quantum.insert %6[ 1], %out_qubits : !quantum.reg, !quantum.bit
    quantum.dealloc %7 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
    }
    
    // Tape 0 without set_state: identical to tape 1 except that the RX angle is %arg1[0].
    func.func @circuit_pytest_tape_0(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<3xf64>) -> tensor<f64> {
    quantum.device["/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/utils/../../../runtime/build/lib/librtd_lightning.so", "LightningSimulator", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
    %0 = quantum.alloc( 2) : !quantum.reg
    %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit
    %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
    //%3:2 = quantum.set_state(%arg0) %1, %2 : (tensor<4xcomplex<f64>>, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit)
    //%4 = quantum.namedobs %3#0[ PauliZ] : !quantum.obs
    %4 = quantum.namedobs %1[ PauliZ] : !quantum.obs
    %5 = quantum.expval %4 : f64
    %from_elements = tensor.from_elements %5 : tensor<f64>
    //%6 = quantum.insert %0[ 0], %3#0 : !quantum.reg, !quantum.bit
    %6 = quantum.insert %0[ 0], %1 : !quantum.reg, !quantum.bit
    %extracted_slice = tensor.extract_slice %arg1[0] [1] [1] : tensor<3xf64> to tensor<1xf64>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xf64> into tensor<f64>
    %extracted = tensor.extract %collapsed[] : tensor<f64>
    //%out_qubits = quantum.custom "RX"(%extracted) %3#1 : !quantum.bit
    %out_qubits = quantum.custom "RX"(%extracted) %2 : !quantum.bit
    %7 = quantum.insert %6[ 1], %out_qubits : !quantum.reg, !quantum.bit
    quantum.dealloc %7 : !quantum.reg
    quantum.device_release
    return %from_elements : tensor<f64>
    }
    
    // Same call sequence as the crashing module: tapes 0, 1, 2, returning only tape 0's result.
    func.func private @circuit_pytest(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<f64>, %arg2: tensor<3xf64>, %arg3: tensor<f64>) -> tensor<f64> attributes {diff_method = "parameter-shift", llvm.linkage = #llvm.linkage<internal>, qnode} {
    %0 = call @circuit_pytest_tape_0(%arg0, %arg2) : (tensor<4xcomplex<f64>>, tensor<3xf64>) -> tensor<f64>
    %1 = call @circuit_pytest_tape_1(%arg0, %arg2) : (tensor<4xcomplex<f64>>, tensor<3xf64>) -> tensor<f64>
    %2 = call @circuit_pytest_tape_2(%arg0, %arg2) : (tensor<4xcomplex<f64>>, tensor<3xf64>) -> tensor<f64>
    return %0 : tensor<f64>
    }
    // Runtime hooks wrapping quantum.init / quantum.finalize.
    func.func @setup() {
    quantum.init
    return
    }
    func.func @teardown() {
    quantum.finalize
    return
    }
    }
paul0403 commented 3 weeks ago

A minimal frontend example that reproduces this error (as of main-branch commit https://github.com/PennyLaneAI/catalyst/commit/b405fbbfc68c955a9e16263e6245b3c732a2de83):

# Minimal PennyLane + Catalyst frontend reproducer for the shared-device set_state crash.
# NOTE(review): the snippet omits its imports; presumably `import numpy as np`,
# `import pennylane as qml` and `from catalyst import qjit` -- confirm.
dev = qml.device("lightning.qubit", wires=2)

# Two separate QNodes on the same device. Each prepares the 4-amplitude state
# [1, 0, 0, 0] on wires 0 and 1 with qml.StatePrep (presumably lowered to
# quantum.set_state) and returns the probabilities.
@qml.qnode(dev)
def f():
    qml.StatePrep(
                np.array([complex(1, 0), complex(0, 0), complex(0, 0), complex(0, 0)]),
                wires=[0, 1],
            )
    return qml.probs()

# Identical to f; a second QNode is enough to trigger the crash.
@qml.qnode(dev)
def g():
    qml.StatePrep(
                np.array([complex(1, 0), complex(0, 0), complex(0, 0), complex(0, 0)]),
                wires=[0, 1],
            )
    return qml.probs()

# Calling both QNodes from one qjit-compiled function crashes with
# "malloc(): unaligned tcache chunk detected" (as of the commit referenced above).
@qjit
def main():
    return f(), g()

main()
malloc(): unaligned tcache chunk detected

Quick note: this is an MLIR issue rather than a JAX issue, so I didn't expect it to help, but I checked anyway: replacing `np.array` with `jax.numpy.array` (or vice versa) produces the same error.