PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0
122 stars, 27 forks

Implement a frontend UI to invoke the mlir quantum peephole transformation passes #911

Closed: paul0403 closed this 3 weeks ago

paul0403 commented 1 month ago

Context: Implement a frontend UI to invoke the MLIR quantum peephole transformation passes. Catalyst currently has a remove-chained-self-inverse MLIR pass that removes two neighbouring Hadamard gates in a circuit, but there is no way to invoke it from the frontend. This PR implements the frontend as a decorator on a qnode: when the catalyst.cancel_inverses decorator is applied to a qnode, the remove-chained-self-inverse MLIR pass is run on that qnode.
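To make the pass's effect concrete, here is a minimal pure-Python model of what remove-chained-self-inverse does to a circuit. It is not Catalyst's implementation: the real pass rewrites MLIR, presumably following each qubit value from one gate to the next, whereas this sketch assumes a hypothetical list of `(gate_name, wire)` tuples for single-qubit gates and a hypothetical function name `cancel_chained_self_inverses`.

```python
from collections import defaultdict


def cancel_chained_self_inverses(ops, self_inverse=frozenset({"Hadamard"})):
    """Remove neighbouring pairs of identical self-inverse gates on the same wire.

    Hypothetical model: `ops` is a list of (gate_name, wire) tuples for
    single-qubit gates, in circuit order.
    """
    out = []                          # surviving ops; None marks a cancelled slot
    last_on_wire = defaultdict(list)  # wire -> stack of indices of surviving ops in `out`
    for name, wire in ops:
        stack = last_on_wire[wire]
        prev = out[stack[-1]] if stack else None
        if name in self_inverse and prev is not None and prev[0] == name:
            out[stack[-1]] = None     # G followed by G on the same wire is the identity
            stack.pop()               # the gate before the pair becomes the new neighbour
        else:
            stack.append(len(out))
            out.append((name, wire))
    return [op for op in out if op is not None]
```

Gates on other wires do not break adjacency, so `[("Hadamard", 0), ("PauliX", 1), ("Hadamard", 0)]` reduces to `[("PauliX", 1)]`, while a non-self-inverse gate such as `RX` on the same wire keeps both Hadamards.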

Following the decision to split #883 into two PRs, this is the frontend portion.

Description of the Change: We implement the peephole optimization library with the MLIR transform dialect (https://mlir.llvm.org/docs/Dialects/Transform/, https://mlir.llvm.org/docs/Tutorials/transform/).

Briefly, we wish to generate the following MLIR during frontend tracing and lowering:

```mlir
module @workflow {

  func.func private @f{ ... }
  func.func private @g{ ... }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) {
      %0 = transform.apply_registered_pass "remove-chained-self-inverse" to %arg0 {options = "func-name=f"} : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module">
      transform.yield
    }
  }
}
```

The outer module is the payload (the module to run the transform on), and the inner module is the transformer (the module that schedules which passes to run). The schedule in this example runs the remove-chained-self-inverse pass on the payload module @workflow, with the pass option func-name=f.

The transform.named_sequence must be terminated by a transform.yield.

A new pass, implemented in ApplyTransformSequence.cpp, performs the following steps:

  1. Remove the transformer module from the top-level payload module (which is its parent), and save the transformer module operation in memory
  2. Perform the transform, through the API provided by the MLIR transform dialect
  3. Delete the transformer module
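The three steps can be modelled on a toy operation tree. This is a sketch under stated assumptions, not the C++ pass: `Op`, `is_transformer`, `apply_transform_sequence` and the `run_schedule` callback (standing in for the transform dialect API) are all hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Toy stand-in for an MLIR operation: a name, attributes and nested ops."""
    name: str
    attrs: dict = field(default_factory=dict)
    body: list = field(default_factory=list)


def is_transformer(op):
    # The transformer module is marked with the unit attribute transform.with_named_sequence.
    return op.name == "builtin.module" and "transform.with_named_sequence" in op.attrs


def apply_transform_sequence(payload, run_schedule):
    # Step 1: detach the transformer module from its parent (the payload) and keep it in memory.
    transformer = next((op for op in payload.body if is_transformer(op)), None)
    if transformer is None:
        return payload  # nothing was scheduled
    payload.body.remove(transformer)
    # Step 2: run the schedule on the payload (hypothetical callback for the transform API).
    run_schedule(payload, transformer)
    # Step 3: the detached transformer is discarded when this function returns.
    return payload
```

Detaching first means the schedule only ever sees payload operations, and the transformer never survives into later pipeline stages.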

We generate this MLIR through two new JAX primitives.

The first, transform_named_sequence_p, is lowered to a transform.named_sequence containing a transform.yield, placed inside a parent transformer module marked with the unit attribute transform.with_named_sequence. This transformer module is inserted into the top-level payload module.

The second, apply_registered_pass_p, is lowered to a transform.apply_registered_pass inside the transform.named_sequence. If the sequence already contains a transform.apply_registered_pass, the new pass is added after the previous one.
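The text these two lowerings produce can be sketched as a string builder that mirrors the MLIR shown earlier. The helper name `build_transformer_module` is hypothetical, and chaining each pass on the previous result handle is an assumption based on transform.apply_registered_pass consuming its target handle and returning a new one; the real lowering builds MLIR operations, not strings.

```python
def build_transformer_module(passes):
    """Emit a transformer module that schedules `passes` in order.

    `passes` is a list of (pass_name, options) pairs, for example
    [("remove-chained-self-inverse", "func-name=f")].
    """
    ty = '!transform.op<"builtin.module">'
    lines = [
        "module attributes {transform.with_named_sequence} {",
        f"  transform.named_sequence @__transform_main(%arg0: {ty}) {{",
    ]
    handle = "%arg0"
    for i, (name, options) in enumerate(passes):
        # Each pass takes the latest handle and produces a new one for the next pass.
        lines.append(
            f'    %{i} = transform.apply_registered_pass "{name}" to {handle} '
            f'{{options = "{options}"}} : ({ty}) -> {ty}'
        )
        handle = f"%{i}"
    # The named sequence must be terminated by a transform.yield.
    lines += ["    transform.yield", "  }", "}"]
    return "\n".join(lines)
```

With an empty pass list the builder still emits a valid, empty named sequence, which matches the idea of always inserting transform_named_sequence_p.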

The design is as follows:

  1. QJIT start
  2. Capture jaxpr (QJIT.capture())
     2.1. Before capture, we insert transform_named_sequence_p into every jaxpr produced (this ensures the transform dialect does not fail).
     2.2. During capture, if a qnode is decorated with @cancel_inverses, inject a corresponding apply_registered_pass_p into the jaxpr.
  3. Generate MLIR and compile (QJIT.generate_ir(), QJIT.compile())
     3.1. The jaxpr primitives are lowered to MLIR.
     3.2. The Compiler class runs the pipeline, which includes the new -apply-transform-sequence pass.
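The capture step (2.1 and 2.2) can be illustrated without JAX. In this sketch `_recorded`, `bind`, `cancel_inverses` and `capture` are hypothetical stand-ins: a plain list plays the role of the jaxpr, `bind` loosely models how a JAX primitive adds an equation, and deriving `func-name` from the decorated function's `__name__` is an assumption.

```python
import functools

_recorded = []  # hypothetical stand-in for the jaxpr being built during capture


def bind(primitive, **params):
    """Record a primitive application, loosely modelling a jaxpr equation."""
    _recorded.append((primitive, params))


def cancel_inverses(qnode):
    """Hypothetical decorator: schedules the pass for this qnode when it is traced."""
    @functools.wraps(qnode)
    def wrapper(*args, **kwargs):
        bind("apply_registered_pass",
             pass_name="remove-chained-self-inverse",
             options=f"func-name={qnode.__name__}")
        return qnode(*args, **kwargs)
    return wrapper


def capture(fn, *args):
    """Model of QJIT.capture(): always emit the named sequence first, then trace fn."""
    _recorded.clear()
    bind("transform_named_sequence")
    fn(*args)
    return list(_recorded)
```

An undecorated function still yields a trace containing the named sequence, so the later transform step always has a schedule to run, even if it is empty.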

Benefits: The user can now run quantum mlir passes from the frontend.

Possible Drawbacks:

Related GitHub Issues:

[sc-67519]

codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 92.40506% with 6 lines in your changes missing coverage. Please review.

Project coverage is 97.89%. Comparing base (7231b5b) to head (6b0ad2e).

| File | Patch % | Lines |
| --- | --- | --- |
| frontend/catalyst/jax_primitives.py | 90.32% | 3 Missing and 3 partials ⚠️ |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #911      +/-   ##
==========================================
- Coverage   97.93%   97.89%   -0.05%
==========================================
  Files          73       74       +1
  Lines       10338    10417      +79
  Branches     1170     1182      +12
==========================================
+ Hits        10125    10198      +73
- Misses        170      173       +3
- Partials       43       46       +3
```


paul0403 commented 1 month ago

TODO: changelog, top-level API documentation of catalyst.cancel_inverses

paul0403 commented 1 month ago

Note on codecov failures: these occur because lit tests are not counted by codecov.

Currently these tests exist:

I believe these tests provide enough usage coverage.

paul0403 commented 1 month ago

Since the JAX update #931, which updated LLVM, the transform dialect enforces a new check: the payload module cannot be an ancestor of the transformer module.

A new commit conforms to this new requirement.
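The new check amounts to an ancestor test on the operation tree. The sketch below uses a hypothetical toy `Op` with parent links and a hypothetical `is_ancestor` helper; it shows that a transformer nested in the payload violates the check and that detaching it (as in step 1 of ApplyTransformSequence) satisfies it. It does not claim to reproduce what the new commit changes.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass(eq=False)  # identity comparison, like MLIR operations
class Op:
    name: str
    parent: Optional["Op"] = field(default=None, repr=False)
    body: list = field(default_factory=list)

    def add(self, child):
        child.parent = self
        self.body.append(child)
        return child

    def detach(self):
        # Assumes the op is currently attached to a parent.
        self.parent.body.remove(self)
        self.parent = None
        return self


def is_ancestor(a, b):
    """True if `a` is a proper ancestor of `b`."""
    node = b.parent
    while node is not None:
        if node is a:
            return True
        node = node.parent
    return False
```

Walking up the parent chain costs time proportional to the nesting depth, which is small for a module containing functions and one transformer.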

New design:

paul0403 commented 3 weeks ago

Looks great! I would be ok with this being merged (with the exception of the walkers not exiting early and skipping nested operations if possible).

(bookkeeping comment) The walkers are now exited early when the desired result is seen.
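Exiting a walk early and skipping nested operations follows MLIR's `WalkResult` convention (`advance`, `skip`, `interrupt`). The Python model below is only an illustration of that control flow: `WalkResult`, `Op`, `walk` and `find_transformer` are hypothetical names, and assuming that function bodies never contain the transformer module is what justifies skipping them.

```python
from dataclasses import dataclass, field
from enum import Enum


class WalkResult(Enum):
    ADVANCE = 0    # keep walking, including into nested ops
    SKIP = 1       # do not visit this op's nested ops
    INTERRUPT = 2  # stop the whole walk


@dataclass
class Op:
    name: str
    body: list = field(default_factory=list)


def walk(op, callback):
    """Pre-order walk; returns INTERRUPT if the callback stopped the walk early."""
    result = callback(op)
    if result is WalkResult.INTERRUPT:
        return result
    if result is WalkResult.ADVANCE:
        for child in op.body:
            if walk(child, callback) is WalkResult.INTERRUPT:
                return WalkResult.INTERRUPT
    return WalkResult.ADVANCE


def find_transformer(root):
    """Return the first nested builtin.module and the names of the ops visited."""
    found, visited = [], []

    def visit(op):
        visited.append(op.name)
        if op is not root and op.name == "builtin.module":
            found.append(op)
            return WalkResult.INTERRUPT  # desired result seen: stop immediately
        if op.name == "func.func":
            return WalkResult.SKIP       # assumed: function bodies never hold the transformer
        return WalkResult.ADVANCE

    walk(root, visit)
    return (found[0] if found else None), visited
```

In a module with a function followed by the transformer, the walk visits only the root, the function header and the transformer, never the gates inside the function or any later operations.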