PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Allow the `sample` and `counts` operations across catalyst to work with a dynamic number of shots #1310

Open paul0403 opened 1 week ago

paul0403 commented 1 week ago

Context: As part of the work to make pennylane circuits accept dynamic device parameters, this PR allows the sample and counts operations across catalyst to work with a dynamic number of shots.

Based on #1170, credit @rauletorresc. [sc-74736] [sc-78842]

Benefits: sample and counts can be used in qnodes whose device has a dynamic shots value, once such a device becomes possible in pennylane.

Description of the Change:

0. Overview of changes

SampleOp and CountsOp in mlir no longer take a shots attribute. Instead, DeviceInitOp now takes shots as an argument (an SSA value).

This is because integer attributes are tied to concrete literal values and hence must be static.

All other changes are in service to the above.

1. Changes in runtime

1.1. Unify all measurement ops' interfaces in the runtime CAPI to not take shots

In the runtime CAPI, among all the measurement APIs, only __catalyst__qis__Sample(...) and __catalyst__qis__Counts(...) take a required int64_t shots argument. Other measurements (state, expval, etc.) do not require this shots argument, even though shots also affect their semantics. This has caused confusion for external device implementors in the past.

All other measurements simply get shots from the device, through the device_shots field of the specific catalyst QuantumDevice class (e.g. the device_shots field in lightning.qubit and its use in probs). We now do the same for sample and counts.

Their shots argument was removed; in their definitions, they now retrieve the device shots through QuantumDevice::GetDeviceShots().
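To make the unified interface concrete, here is a toy Python model of a device whose sample and counts read shots from the device instead of taking them as an argument. This is a sketch under stated assumptions: MockQuantumDevice and its snake_case methods are hypothetical stand-ins that loosely mirror QuantumDevice::SetDeviceShots()/GetDeviceShots(). They are not the C++ runtime API, and the toy device always measures |0...0>.

```python
class MockQuantumDevice:
    """Hypothetical Python stand-in for a runtime QuantumDevice (not the C++ class)."""

    def __init__(self):
        self._device_shots = 0

    def set_device_shots(self, shots):
        self._device_shots = shots

    def get_device_shots(self):
        return self._device_shots

    def sample(self, num_qubits):
        # No shots parameter: the number of rows comes from the device itself,
        # just as state/expval/probs already do.
        shots = self.get_device_shots()
        # Toy device: every shot measures |0...0>.
        return [[0] * num_qubits for _ in range(shots)]

    def counts(self, num_qubits):
        # The result length depends only on num_qubits; shots only set the totals.
        shots = self.get_device_shots()
        counts = [0] * (2 ** num_qubits)
        counts[0] = shots  # all shots land on the all-zeros basis state
        return counts


dev = MockQuantumDevice()
dev.set_device_shots(5)
print(len(dev.sample(2)), dev.counts(2))  # 5 [5, 0, 0, 0]
```

Callers of sample and counts no longer need to know the shot count; changing it on the device is enough.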

1.2 Set the device shots upon device creation in CAPI

An extra argument, int64_t shots, was added to __catalyst__rt__device_init().

It is propagated through a few helper functions and eventually reaches getOrCreateDevice() in the ExecutionContext. There, at device creation, we immediately call SetDeviceShots() with the shots argument passed to the device init CAPI.

Correspondingly, the device shots are no longer parsed from the device init op's attribute dictionary string. As an attribute, shots could only be a concrete integer literal and could never be dynamic, so shots should no longer appear in the device attributes dictionary.

TODO: in this repo we can only remove the shots attribute parsing for the openqasm device. The lightning devices live in the lightning repo, and their shots attribute parsing needs to be removed there.

In the frontend, when registering the device kwargs into the device init op attributes, we no longer add the device shots to the attributes dict. This change is in frontend/catalyst/device/qjit_device.py.
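The flow of section 1.2 can be sketched in Python. Everything here is hypothetical: ToyExecutionContext, device_init and device_kwargs_to_attr only mimic the roles of ExecutionContext::getOrCreateDevice(), __catalyst__rt__device_init() and the frontend change in qjit_device.py. The real code is C/C++ with different signatures. The point is that shots travel as their own argument and are applied when the device is created, while the kwargs string no longer mentions them.

```python
class ToyDevice:
    def __init__(self, name):
        self.name = name
        self.device_shots = 0

    def set_device_shots(self, shots):
        self.device_shots = shots


class ToyExecutionContext:
    """Caches one device per name, loosely like ExecutionContext."""

    def __init__(self):
        self._devices = {}

    def get_or_create_device(self, name, shots):
        device = self._devices.get(name)
        if device is None:
            device = ToyDevice(name)
            # Set the shots immediately at creation time.
            device.set_device_shots(shots)
            self._devices[name] = device
        return device


def device_kwargs_to_attr(device_kwargs):
    """Frontend side: serialize device kwargs without the "shots" entry."""
    return str({k: v for k, v in device_kwargs.items() if k != "shots"})


def device_init(context, name, kwargs_attr, shots):
    # shots is a separate (possibly runtime-computed) argument;
    # kwargs_attr is never consulted for shots.
    return context.get_or_create_device(name, shots)


ctx = ToyExecutionContext()
attr = device_kwargs_to_attr({"shots": 1000, "mcmc": False})
dev = device_init(ctx, "lightning.qubit", attr, 1000)
print(attr, dev.device_shots)  # {'mcmc': False} 1000
```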

1.3 Impacted tests

In the runtime tests, for all manually created devices, shots are now set through an argument to __catalyst__rt__device_init(..., shots) or through SetDeviceShots(shots) (whichever is more appropriate), instead of through the attribute dict.

All shots arguments in sample and counts are removed.

2. Changes in mlir

2.1 Changes in operation definitions

2.2 Changes in mlir pipeline

2.2.1 Conversion to CAPI calls

In --convert-quantum-to-llvm (ConversionPatterns.cpp), the call to the device init CAPI now passes the shots argument, and the calls to the sample and counts CAPIs no longer pass it.

2.2.2 Bufferization

In --quantum-bufferize (BufferizationPatterns.cpp), the return shape of sample is (shots, number_of_qubits), which is now dynamic in shots, so the memref allocation must take the dynamic shots value. The shots SSA value must be retrieved from the device init op in the qnode function. Note that counts is unaffected, as its return shape does not depend on shots.

Note that the return shape of the sample operation can be either static or dynamic, depending on whether the frontend shots were static or dynamic. For static shots, bufferization should still emit the memref alloc without any dynamic size arguments and simply allocate the static result shape.
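A minimal sketch of the two allocation cases, written as a Python helper that emits textual MLIR. Assumptions: the helper name sample_alloc is hypothetical, and the sample buffer is taken to hold f64 values. Because memref.alloc's dynamic sizes must be of index type, the sketch first converts the i64 shots value from the device init op with arith.index_cast.

```python
def sample_alloc(num_qubits, static_shots=None, shots_ssa="%shots"):
    """Emit textual MLIR allocating the (shots, num_qubits) sample buffer.

    static_shots: an int when shots are known at compile time, else None.
    shots_ssa:    the i64 SSA value taken from the device init op (dynamic case).
    """
    if static_shots is not None:
        # Static shots: fully static shape, no dynamic size operands.
        return f"%alloc = memref.alloc() : memref<{static_shots}x{num_qubits}xf64>"
    # Dynamic shots: the first dimension is '?', and its size is passed as an index.
    return "\n".join([
        f"%shots_idx = arith.index_cast {shots_ssa} : i64 to index",
        f"%alloc = memref.alloc(%shots_idx) : memref<?x{num_qubits}xf64>",
    ])


print(sample_alloc(2, static_shots=1000))
print(sample_alloc(2))
```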

2.3 ZNE changes

In ZNE, the folded circuit needs its own copy of the device init operation. The ZNE pass creates it manually in the folded circuit (with rewriter.create(...)), which means the shots SSA value needs to be cloned into the folded circuit as well.

2.4 Impacted tests

Trivially changed to adhere to the above changes.

3. Changes in frontend

3.1. Device shots initialization

As mentioned, the device shots are no longer kept as an attribute for the backend to parse.

3.2 Primitive definitions and lowering

3.2.1 Device Init

The device init primitive now takes shots as an argument. The mlir operation expects an i64 for shots, while frontend tracing produces stablehlo "scalar tensors", so a TensorExtractOp needs to be inserted.

All binds should pass shots as a positional argument, so that it becomes an SSA value argument of the jaxpr primitive and is therefore lowered into an SSA value in mlir.
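The unwrapping step from 3.2.1 looks like this in textual IR: tensor.extract with an empty index list reads the single element of a 0-d tensor. The helper below is a hypothetical sketch that only produces the text, and the SSA names are made up.

```python
def extract_scalar_shots(tensor_ssa, result_ssa="%shots"):
    """Emit the tensor.extract that turns a 0-d tensor<i64> into a plain i64 value."""
    # A 0-d tensor has no indices, hence the empty brackets.
    return f"{result_ssa} = tensor.extract {tensor_ssa}[] : tensor<i64>"


print(extract_scalar_shots("%arg0"))
# %shots = tensor.extract %arg0[] : tensor<i64>
```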

3.2.2 Sample and Counts

Instead of taking shots and shape, they now take shots and num_qubits. shots is no longer converted into an attribute on the lowered mlir operation, since the mlir operations no longer have that attribute.

This PR only deals with making shots dynamic, so num_qubits is still just an integer. This means the result shape of counts is still static.

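Two cases are possible here for sample:

  1. Static shots: shots is an integer literal, so the result shape (shots, num_qubits) is fully static.
  2. Dynamic shots: shots is a traced (SSA) value, so the first dimension of the result shape is dynamic.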

Note that in jaxpr's abstract_eval, a DShapedArray can have either static or dynamic values in its shape, so we use it as the result of the sample primitive's abstract_eval.
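A stdlib-only sketch of the shape logic. DynamicDim is a hypothetical placeholder for a dimension known only at run time. In the real code, this role is played by a tracer inside a DShapedArray. The counts shape of 2**num_qubits entries assumes computational-basis counts; the point is only that it does not involve shots.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class DynamicDim:
    """Hypothetical stand-in for a dimension whose size is known only at run time."""
    name: str


Dim = Union[int, DynamicDim]


def sample_result_shape(shots: Dim, num_qubits: int) -> Tuple[Dim, int]:
    # Static shots give a fully static shape; dynamic shots give a dynamic first dim.
    return (shots, num_qubits)


def counts_result_shapes(num_qubits: int) -> Tuple[Tuple[int], Tuple[int]]:
    # Assumption: (eigenvalues, counts) over all 2**num_qubits basis states.
    # No shots involved, so the shape stays static.
    n = 2 ** num_qubits
    return (n,), (n,)


print(sample_result_shape(1000, 2))
print(sample_result_shape(DynamicDim("shots"), 2))
print(counts_result_shapes(2))
```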

3.3 Tracing

For sample and counts, when bind()-ing a primitive to the jaxpr during tracing, static shots must be passed as a keyword argument to bind(), and dynamic shots as a positional argument. This way, static shots become a jaxpr primitive (integer) parameter, and dynamic shots become a jaxpr primitive SSA value argument. This happens in jax_tracers.py/trace_quantum_measurements.

When tracing device init primitives, if the pennylane device shots are None, we set the shots to 0; otherwise we use the pennylane device's shots integer. Note that PL device shots are currently always static (i.e. either None or an integer literal), but we still bind them as a positional argument, as if they were dynamic, so that we are ready for when PL devices support dynamic shots.

The above tracing changes are also added to frontend/catalyst/from_plxpr.py for plxpr support.
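The binding rules of 3.3 can be sketched with a fake primitive that simply records its bind() calls. RecordingPrimitive, bind_sample and bind_device_init are hypothetical names, and the argument layout (observable first, then shots) is an assumption made for illustration. This is not Catalyst's actual jax_tracers.py code.

```python
class RecordingPrimitive:
    """Hypothetical stand-in for a JAX primitive: records how bind() was called."""

    def __init__(self, name):
        self.name = name
        self.calls = []

    def bind(self, *args, **params):
        self.calls.append((args, params))
        return args, params


def bind_sample(sample_p, obs, shots, num_qubits):
    if isinstance(shots, int):
        # Static shots: keyword -> a jaxpr (integer) parameter.
        return sample_p.bind(obs, shots=shots, num_qubits=num_qubits)
    # Dynamic shots (a tracer in real code): positional -> a jaxpr SSA argument.
    return sample_p.bind(obs, shots, num_qubits=num_qubits)


def bind_device_init(device_init_p, pl_device_shots, **device_params):
    # PennyLane's "no shots" (None) becomes 0. Shots are always passed positionally,
    # as if they were dynamic, so the lowering already produces an SSA operand.
    shots = 0 if pl_device_shots is None else pl_device_shots
    return device_init_p.bind(shots, **device_params)


sample_p = RecordingPrimitive("sample")
bind_sample(sample_p, "obs", 100, 2)
bind_sample(sample_p, "obs", "dynamic_shots_tracer", 2)
print(sample_p.calls)
```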

3.4 Testing

Regarding testing, there are 3 places that need to be tested throughout the stack:

  1. A valid pennylane frontend qnode gets qjitted into a valid catalyst jaxpr, with shots being a proper argument and the shape being either static or dynamic
  2. A valid jaxpr gets lowered into valid mlir, with the SampleOp's result shape being either static or dynamic
  3. An mlir module with a dynamically shaped sample op can be executed correctly in the backend

Given that qml.device(shots=non_int_literal_expression) is not yet supported in the frontend, step 1 cannot be performed in full. I propose the following:

  - [lit test] Manually write functions with sample_p.bind(...), i.e. bind primitives directly without going through the pennylane frontend. This tests that the expected uses of the bind() functions produce the expected jaxprs.
  - [lit test] Manually lower these jaxprs to mlir.
  - [pytest] Manually send a textual IR with dynamically shaped sample ops into the backend through replace_ir.

This leaves the tracing step, i.e. how we actually bind the primitives in jax_tracers.py during qjit, untested for dynamic shapes (static shapes are still tested by all the existing tests that use sample). However, until PL devices support dynamic shots, I don't see a workaround for this.

paul0403 commented 1 week ago

This is still WIP, but since this is a very big change I want CI to tell me how many things are breaking.

I will remove the do-not-merge label once it's ready.

codecov[bot] commented 16 hours ago

Codecov Report

Attention: Patch coverage is 46.80851% with 25 lines in your changes missing coverage. Please review.

Project coverage is 80.84%. Comparing base (05a1e98) to head (a75bc93).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| runtime/tests/Test_OpenQasmDevice.cpp | 0.00% | 17 Missing |
| runtime/tests/Test_NullQubit.cpp | 0.00% | 6 Missing |
| frontend/catalyst/from_plxpr.py | 75.00% | 1 Missing and 1 partial |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #1310      +/-   ##
==========================================
- Coverage   80.87%   80.84%   -0.03%
==========================================
  Files          73       73
  Lines        8131     8134       +3
  Branches      840      841       +1
==========================================
  Hits         6576     6576
- Misses       1502     1504       +2
- Partials       53       54       +1
```
