paul0403 opened 1 week ago
This is still WIP, but since this is a very big change, I want CI to tell me how many things are breaking.
I will remove the do-not-merge label once it is ready.
Attention: Patch coverage is 46.80851% with 25 lines in your changes missing coverage. Please review.
Project coverage is 80.84%. Comparing base (05a1e98) to head (a75bc93).
Context: As part of the work to make PennyLane circuits accept dynamic device parameters, this PR allows the `sample` and `counts` operations across Catalyst to work with a dynamic number of shots. Based on #1170, credit @rauletorresc [sc-74736] [sc-78842]
Benefits: `sample` and `counts` can be used for QNodes with a dynamic `shots` value in the device, once such a device becomes possible in PennyLane.

Description of the Change:
0. Overview of changes

`SampleOp` and `CountsOp` in MLIR no longer take a `shots` attribute. `DeviceInitOp` now takes `shots` as an operand. This is because an integer attribute holds a concrete literal value known at compile time, so shots stored as an attribute must be static, whereas an operand can be any SSA value, including one only known at runtime.

All other changes are in service of the above.
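To make the static-versus-dynamic distinction concrete, here is a minimal pure-Python sketch (not Catalyst code) of how the shots dimension of a sample result type depends on whether shots is an integer literal or an SSA value. `SSAValue` and `sample_result_type` are hypothetical names; the printed types mirror the `tensor<5x1xf64>` and `tensor<?x1xf64>` examples given in 3.2.2.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class SSAValue:
    """Hypothetical stand-in for an MLIR SSA value (e.g. the result of an op)."""
    name: str


def sample_result_type(shots: Union[int, SSAValue], num_qubits: int) -> str:
    """Return the tensor type of a sample result with shape (shots, num_qubits)."""
    if isinstance(shots, SSAValue):
        # Dynamic shots: only known at runtime, so the shots dimension is "?".
        return f"tensor<?x{num_qubits}xf64>"
    # Static shots: an integer literal, so the shots dimension is known at compile time.
    return f"tensor<{shots}x{num_qubits}xf64>"


print(sample_result_type(5, 1))                   # tensor<5x1xf64>
print(sample_result_type(SSAValue("%shots"), 1))  # tensor<?x1xf64>
```

An attribute could only ever produce the first form; moving shots to an operand is what makes the second form expressible.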
1. Changes in runtime

1.1. Unify all measurement ops' interface in the runtime CAPI to not take shots

In the runtime CAPI, among all the measurement APIs, only `__catalyst__qis__Sample(...)` and `__catalyst__qis__Counts(...)` take a required `int64_t shots` argument. Other measurements (state, expval, etc.) do not take this argument, even though shots also matter to their semantics. This has confused external device implementors in the past.

Noticing that all other measurements simply get shots from the device (through the `device_shots` field of the specific Catalyst `QuantumDevice` class, e.g. in lightning.qubit, and where it is used by probs), we do the same for sample and counts. Their `shots` argument was removed, and their definitions now retrieve the device shots through `QuantumDevice::GetDeviceShots()`.

1.2 Set the device shots upon device creation in the CAPI

An extra argument, `int64_t shots`, was added to `__catalyst__rt__device_init()`. It is propagated through a few helper functions and eventually reaches `getOrCreateDevice()` in the `ExecutionContext`. There, at device creation, we immediately call `SetDeviceShots()` with the shots argument passed to the device init CAPI.

Correspondingly, the device shots are no longer parsed from the device init op's attribute dictionary string. As an attribute, shots could only be a concrete integer literal and could never be dynamic, so shots should no longer be part of the device attributes dictionary.
TODO: here we can only remove the shots attribute parsing in the OpenQASM device. The Lightning devices live in the Lightning repo, and their parsing needs to be removed there.
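The runtime changes in 1.1 and 1.2 can be summarized with a toy Python model (not the C++ runtime): shots arrive once, as an argument at device creation, and sample and counts read them back from the device instead of taking their own shots argument. `QuantumDevice`, `ExecutionContext`, `get_or_create_device`, and `device_init` here are hypothetical snake_case analogues of `QuantumDevice`, `ExecutionContext::getOrCreateDevice()`, and `__catalyst__rt__device_init()`, and the "simulator" is assumed to always sit in the all-zero state.

```python
from typing import Dict, List


class QuantumDevice:
    """Toy Python model of a runtime device; not the real C++ QuantumDevice class."""

    def __init__(self) -> None:
        self.device_shots = 0

    def set_device_shots(self, shots: int) -> None:
        self.device_shots = shots

    def get_device_shots(self) -> int:
        return self.device_shots

    def sample(self, num_qubits: int) -> List[List[int]]:
        # No shots argument: like every other measurement, read shots from the device.
        # Toy simulator that is always in |0...0>, so every shot samples all zeros.
        shots = self.get_device_shots()
        return [[0] * num_qubits for _ in range(shots)]

    def counts(self, num_qubits: int) -> Dict[int, int]:
        # One bin per basis state; every shot lands on basis state 0.
        shots = self.get_device_shots()
        bins = {basis: 0 for basis in range(2 ** num_qubits)}
        bins[0] = shots
        return bins


class ExecutionContext:
    """Toy model of the execution context that owns the devices."""

    def __init__(self) -> None:
        self.devices: Dict[str, QuantumDevice] = {}

    def get_or_create_device(self, name: str, shots: int) -> QuantumDevice:
        if name not in self.devices:
            device = QuantumDevice()
            # Set the shots right at creation, from the device-init argument.
            device.set_device_shots(shots)
            self.devices[name] = device
        return self.devices[name]


def device_init(context: ExecutionContext, name: str, attr_dict: str, shots: int) -> QuantumDevice:
    # attr_dict no longer carries shots; shots arrives as a separate argument.
    return context.get_or_create_device(name, shots)


ctx = ExecutionContext()
dev = device_init(ctx, "toy.qubit", "{}", shots=3)
print(dev.sample(2))  # [[0, 0], [0, 0], [0, 0]]
print(dev.counts(1))  # {0: 3, 1: 0}
```

Because every measurement now reads the same device field, an external device implementor only has to honor shots in one place.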
In the frontend, when registering the device kwargs into the device init op attributes, we no longer add the device shots to the attributes dict. This is the change in `frontend/catalyst/device/qjit_device.py`.

1.3 Impacted tests

In runtime tests, for all manually created devices, shots are no longer set through the attribute dict; they are passed as an argument to `__catalyst__rt__device_init(..., shots)` or set with `SetDeviceShots(shots)`, whichever is more appropriate. All shots arguments to sample and counts are removed.
2. Changes in MLIR

2.1 Changes in operation definitions

`DeviceInitOp` now takes an `I64` operand for `shots`. This means shots can be any SSA value, and is in principle dynamic! `SampleOp` and `CountsOp` no longer take the `shots` attribute. Correspondingly, the shots-related verifications (in `mlir/lib/Quantum/IR/QuantumOps.cpp`) need to be removed.

2.2 Changes in the MLIR pipeline

2.2.1 Conversion to CAPI calls

In `--convert-quantum-to-llvm` (`ConversionPatterns.cpp`), the device init CAPI is now called with the shots argument, and the sample and counts CAPIs are now called without it.

2.2.2 Bufferization

In `--quantum-bufferize` (`BufferizationPatterns.cpp`), the return shape of sample is (shots, number_of_qubits), which is now dynamic in shots, so the memref allocation must take the dynamic shots value. The shots SSA value must be retrieved from the device init op in the qnode function. Note that counts is unaffected, as its return shape does not depend on shots.

The return shape of the sample operation can be either static or dynamic, depending on whether the frontend shots were static or dynamic. For static shots, bufferization should still emit the memref alloc without any dynamic size operands and simply allocate the result shape.
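The bufferization rule above can be sketched in Python as a function that, given the sample result shape (with `None` standing for a dynamic `?` dimension) and the name of the shots SSA value taken from the device init op, returns the memref type and the list of dynamic-size operands for the allocation. `alloc_for_sample` is a hypothetical helper, not the pattern in `BufferizationPatterns.cpp`. Note that `memref.alloc` size operands have MLIR `index` type, so a real lowering would presumably also need to cast the `i64` shots value; that step is left out here.

```python
from typing import List, Optional, Sequence, Tuple

# A dimension is a static size (int) or None for a dynamic ("?") size.
Dim = Optional[int]


def alloc_for_sample(result_shape: Sequence[Dim], shots_value: str) -> Tuple[str, List[str]]:
    """Return the memref type and the dynamic-size operands for a sample result buffer.

    `shots_value` names the shots SSA value taken from the qnode's device init op.
    """
    if len(result_shape) != 2:
        raise ValueError("sample results have shape (shots, num_qubits)")
    if result_shape[1] is None:
        raise ValueError("only the shots dimension may be dynamic in this sketch")
    dims = ["?" if d is None else str(d) for d in result_shape]
    memref_type = "memref<" + "x".join(dims) + "xf64>"
    # memref.alloc takes one size operand per dynamic dimension; static shots need none.
    dynamic_sizes = [shots_value] if result_shape[0] is None else []
    return memref_type, dynamic_sizes


print(alloc_for_sample([5, 1], "%shots"))     # ('memref<5x1xf64>', [])
print(alloc_for_sample([None, 1], "%shots"))  # ('memref<?x1xf64>', ['%shots'])
```

The static case deliberately produces no operands, matching the requirement that static shots keep the old allocation unchanged.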
2.3 ZNE changes

In ZNE, the folded circuit needs to copy a device init operation. The ZNE pass does so by manually creating a new device init operation in the folded circuit (with `rewriter.create(...)`). This means the shots SSA value needs to be cloned into the folded circuit as well.

2.4 Impacted tests

Trivially changed to adhere to the above changes.
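For the ZNE change in 2.3, the idea of cloning the shots SSA value into the folded circuit can be sketched with a toy IR: before recreating the device init op, recursively clone the ops that define its shots operand, keeping an old-to-new value mapping (similar in spirit to MLIR's `IRMapping`). `Op`, `clone_value`, and `folded_device_init` are hypothetical, the op names are illustrative, and the sketch assumes shots is defined by an op (for example a constant) rather than by a block argument.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Op:
    """Toy operation: name, SSA operand names, an optional SSA result name, attributes."""
    name: str
    operands: List[str]
    result: str = ""
    attrs: Dict[str, object] = field(default_factory=dict)


def clone_value(value: str, defs: Dict[str, Op], mapping: Dict[str, str],
                body: List[Op], rename: Callable[[str], str]) -> str:
    """Clone the op defining `value` (and, recursively, its operands) into `body`."""
    if value in mapping:
        return mapping[value]
    if value not in defs:
        # Block arguments have no defining op; they would have to be passed in instead.
        raise ValueError(f"{value} has no defining op to clone")
    op = defs[value]
    operands = [clone_value(v, defs, mapping, body, rename) for v in op.operands]
    result = rename(op.result)
    body.append(Op(op.name, operands, result, dict(op.attrs)))
    mapping[value] = result
    return result


def folded_device_init(device_init: Op, defs: Dict[str, Op]) -> List[Op]:
    """Start a folded circuit: clone the shots computation, then recreate the device init."""
    body: List[Op] = []
    mapping: Dict[str, str] = {}
    shots = clone_value(device_init.operands[0], defs, mapping, body,
                        rename=lambda name: name + "_folded")
    body.append(Op(device_init.name, [shots], "", dict(device_init.attrs)))
    return body


const = Op("constant", [], "%shots", {"value": 100})
init = Op("device_init", ["%shots"], attrs={"lib": "toy"})
for op in folded_device_init(init, {"%shots": const}):
    print(op)
```

Without the cloning step, the new device init op would refer to a value defined in a different function, which is not valid IR.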
3. Changes in frontend

3.1. Device shots initialization

As mentioned, the device shots are no longer kept as an attribute for the backend to parse.
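A minimal sketch of the frontend side, assuming simplified helpers rather than the real code in `qjit_device.py`: shots are left out of the serialized device kwargs (3.1), PennyLane's `None` shots become 0 before binding (the rule stated in 3.3), and the traced rank-0 `tensor<i64>` is turned into the `i64` the device init op expects with a `tensor.extract` (3.2.1). All three function names and the `seed` kwarg are hypothetical.

```python
from typing import Any, Dict, Optional


def device_attr_dict(device_kwargs: Dict[str, Any]) -> str:
    """Serialize device kwargs for the device init op's attribute string, leaving shots out."""
    return str({key: value for key, value in device_kwargs.items() if key != "shots"})


def device_shots_operand(pl_shots: Optional[int]) -> int:
    """Shots value bound positionally to the device init primitive: None becomes 0."""
    return 0 if pl_shots is None else pl_shots


def extract_shots(scalar_tensor: str, result: str) -> str:
    """Textual tensor.extract turning a rank-0 tensor<i64> into the i64 the MLIR op expects."""
    return f"{result} = tensor.extract {scalar_tensor}[] : tensor<i64>"


print(device_attr_dict({"shots": 100, "seed": 42}))  # {'seed': 42}
print(device_shots_operand(None))                     # 0
print(extract_shots("%arg0", "%shots"))               # %shots = tensor.extract %arg0[] : tensor<i64>
```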
3.2 Primitive definitions and lowering

3.2.1 Device Init

The device init primitive now takes shots as an argument. Since the MLIR operation expects an `I64` for shots, while frontend tracing produces StableHLO scalar tensors, a `TensorExtractOp` needs to be inserted.

All `bind`s should pass shots as a positional argument, so that it becomes an SSA value argument of the jaxpr primitive and hence gets lowered into an SSA value in MLIR.

3.2.2 Sample and Counts

Instead of taking `shots` and `shape`, they now take `shots` and `num_qubits`. `shots` is no longer converted to an attribute on the lowered MLIR operation (since the MLIR operations no longer have that attribute).

This PR only deals with making shots dynamic, so `num_qubits` is still just an integer. This means the result shape of `counts` is still static.

Two cases are possible here for sample:

- Static `shots`, i.e. just an integer parameter of the primitive. In this case we give the lowered MLIR sample operation a static shape (e.g. `tensor<5x1xf64>`).
- Dynamic `shots`, i.e. an SSA value argument of the primitive. In this case we give the lowered MLIR sample operation a dynamic shape in the shots dimension (e.g. `tensor<?x1xf64>`).

Note that in the jaxpr's abstract_eval, a `DShapedArray` can have either static or dynamic values in its shape, so we use it as the sample primitive's abstract_eval result.

3.3 Tracing

For sample and counts, when `bind()`-ing a primitive to the jaxpr during tracing, static shots need to be passed as a keyword argument to `bind()`, and dynamic shots need to be passed as a positional argument. This way, static shots become an (integer) parameter of the jaxpr primitive, and dynamic shots become an SSA value argument of the jaxpr primitive. This happens in `jax_tracers.py/trace_quantum_measurements`.

When tracing device init primitives, if the PennyLane device shots is `None`, we set the shots to 0; otherwise we use the PennyLane device's shots integer. Note that PL device shots are currently always static (i.e. either `None` or `int_literal_value`), but we still bind with a positional argument, as if they were dynamic, so that we are ready when PL devices support dynamic shots.

The above tracing changes are also added to `frontend/catalyst/from_plxpr.py` for plxpr support.

3.4 Testing

Regarding testing, there are 3 places that need to be tested throughout the stack:

Given that `qml.device(shots=non_int_literal_expression)` is not yet supported in the frontend, step 1 cannot be performed in full. I propose the following:

- [lit test] Manually write functions with `sample_p.bind(...)`, i.e. bind primitives directly without going through the PennyLane frontend. This tests that expected uses of the `bind()` functions produce the expected jaxprs.
- [lit test] Manually lower these jaxprs to MLIR.
- [pytest] Manually send a textual IR with dynamically shaped sample ops into the backend through `replace_ir`.

This leaves the tracing step, i.e. how we actually bind the primitives in `jax_tracers.py` during qjit, untested for dynamic shapes (static shapes are still tested by all the existing tests that use `sample`); however, until PL devices support dynamic shots, I don't see a workaround for this.
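The keyword-versus-positional convention from 3.3, which the proposed lit tests check by binding primitives directly, can be illustrated with a toy recorder in plain Python rather than JAX. `Tracer`, `Eqn`, and `bind_sample` are hypothetical and do not reproduce the real `sample_p.bind` signature; they only show where static and dynamic shots end up.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union


@dataclass(frozen=True)
class Tracer:
    """Toy stand-in for a traced value that becomes an SSA value in the jaxpr."""
    name: str


@dataclass
class Eqn:
    """Toy jaxpr equation: primitive name, SSA arguments, static params, output shape."""
    primitive: str
    args: List[Tracer]
    params: Dict[str, Any]
    out_shape: Tuple[Union[int, Tracer], ...]


def bind_sample(shots: Union[int, Tracer], num_qubits: int) -> Eqn:
    """Record a sample equation following the bind() convention described in 3.3."""
    if isinstance(shots, Tracer):
        # Dynamic shots go in positionally and become an SSA value argument.
        args, params = [shots], {"num_qubits": num_qubits}
    else:
        # Static shots go in as a keyword and become an integer parameter.
        args, params = [], {"shots": shots, "num_qubits": num_qubits}
    # Like a DShapedArray, the output shape may mix static ints and dynamic values.
    return Eqn("sample", args, params, (shots, num_qubits))


static_eqn = bind_sample(5, 1)
dynamic_eqn = bind_sample(Tracer("a"), 1)
print(static_eqn.params)  # {'shots': 5, 'num_qubits': 1}
print(dynamic_eqn.args)   # [Tracer(name='a')]
```

A lit test for the real primitives would check the same two facts in the printed jaxpr: static shots appear among the equation's parameters, while dynamic shots appear among its inputs.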