PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

Adding an is_abstract check in pauli_decompose so it does not break under jit (either jax.jit or catalyst.qjit) #5878

Closed: paul0403 closed this 1 week ago

paul0403 commented 1 week ago



Context: pauli_decompose currently fails under jax.jit and catalyst.qjit.

Description of the Change: An is_abstract check is added so that, under jit, pauli_decompose no longer tries to evaluate an allclose comparison at compile time on an abstract tracer value.
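A minimal sketch of the guard idea in plain JAX, assuming that "abstract" means "is a JAX tracer". The helpers `is_abstract_array` and `maybe_drop_zero_terms` are hypothetical and are not PennyLane's implementation (PennyLane has its own `qml.math.is_abstract`). The point is only that a value-dependent check such as `allclose` runs in eager mode and is skipped while tracing.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_abstract_array(x):
    """Hypothetical helper: True if x is a JAX tracer (e.g. inside jax.jit)."""
    return isinstance(x, jax.core.Tracer)


def maybe_drop_zero_terms(coeffs, atol=1e-8):
    """Hypothetical helper: drop near-zero coefficients only when values are concrete."""
    if is_abstract_array(coeffs):
        # While tracing, coeffs has a shape and dtype but no values, so Python
        # control flow on np.allclose / np.isclose cannot be evaluated.
        # Return every term unchanged, which also keeps the output shape static.
        return coeffs
    # Eager path: the values are known, so a data-dependent filter is fine.
    values = np.asarray(coeffs)
    return values[~np.isclose(values, 0.0, atol=atol)]


coeffs = jnp.array([0.25, 0.0, 0.0, 0.25])
print(maybe_drop_zero_terms(coeffs))           # eager: 2 entries
print(jax.jit(maybe_drop_zero_terms)(coeffs))  # jitted: all 4 entries
```

The trade-off is the one discussed below: the jitted path stays compilable but returns every term, zeros included.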

Benefits:

Possible Drawbacks:

Related GitHub Issues: closes #5817

[sc-65302] [sc-65344]

github-actions[bot] commented 1 week ago

Hello. You may have forgotten to update the changelog! Please edit doc/releases/changelog-dev.md with:

astralcai commented 1 week ago

A side effect of this change:

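>>> import functools
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import pennylane as qml
>>> x = jnp.array(np.diag([0, 0, 1]))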
>>> coeffs, unitaries = qml.pauli_decompose(x, check_hermitian=False).terms()
>>> print(coeffs, unitaries)
[ 0.25+0.j  0.25+0.j -0.25+0.j -0.25+0.j] [I(0) @ I(1), I(0) @ Z(1), Z(0) @ I(1), Z(0) @ Z(1)]

and after jitting,

>>> coeffs, unitaries = jax.jit(functools.partial(qml.pauli_decompose, check_hermitian=False))(x).terms()
>>> print(coeffs, unitaries)
[ 0.25+0.j  0.  +0.j  0.  +0.j  0.25+0.j  0.  +0.j  0.  +0.j  0.  +0.j
  0.  +0.j  0.  +0.j  0.  +0.j -0.  +0.j  0.  +0.j -0.25+0.j  0.  +0.j
  0.  +0.j -0.25+0.j] [I(0) @ I(1), I(0) @ X(1), I(0) @ Y(1), I(0) @ Z(1), X(0) @ I(1), X(0) @ X(1), X(0) @ Y(1), X(0) @ Z(1), Y(0) @ I(1), Y(0) @ X(1), Y(0) @ Y(1), Y(0) @ Z(1), Z(0) @ I(1), Z(0) @ X(1), Z(0) @ Y(1), Z(0) @ Z(1)]

But I guess that, since jit has to work with statically shaped arrays, this is unavoidable.
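To see why the jitted result carries all 16 terms, here is a minimal jit-compatible sketch of a Pauli decomposition that always returns one coefficient per Pauli word, using c_P = Tr(P H) / 2^n. The function `all_pauli_coeffs` is hypothetical and is not PennyLane's implementation. The input is the 4×4 matrix diag(0, 0, 1, 0), i.e. the zero-padded form of the 3×3 example above, which reproduces the nonzero coefficients printed in both outputs.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np

PAULIS = {
    "I": np.eye(2, dtype=complex),
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}


def pauli_basis(n):
    """All 4**n Pauli words on n qubits, qubit 0 leftmost (II, IX, IY, IZ, XI, ...)."""
    labels = ["".join(w) for w in itertools.product("IXYZ", repeat=n)]
    mats = []
    for label in labels:
        m = np.array([[1.0 + 0j]])
        for ch in label:
            m = np.kron(m, PAULIS[ch])
        mats.append(m)
    return labels, np.stack(mats)


def all_pauli_coeffs(H):
    """Hypothetical helper: c_P = Tr(P H) / 2**n for every Pauli word P."""
    # Shapes are static under jit, so n is an ordinary Python int here.
    n = int(np.log2(H.shape[0]))
    _, basis = pauli_basis(n)
    # einsum computes Tr(P_k @ H) for each k; the result always has 4**n
    # entries, whether or not some of them are zero.
    return jnp.einsum("kij,ji->k", basis, H) / 2**n


H = jnp.diag(jnp.array([0.0, 0.0, 1.0, 0.0]))
labels, _ = pauli_basis(2)
coeffs = jax.jit(all_pauli_coeffs)(H)
print(coeffs.shape)  # (16,)
print({l: complex(c) for l, c in zip(labels, coeffs) if abs(c) > 1e-12})
```

The output length depends only on the input shape, never on the values, which is exactly what a compiler that needs static shapes can accept.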

paul0403 commented 1 week ago


Yes, jit needs to know the return shape in advance when compiling (especially in Catalyst, where everything is lowered to MLIR, which is strongly typed and shaped). There isn't any real workaround for this.
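A small illustration of the static-shape constraint in plain JAX: filtering by a boolean mask gives an output whose length depends on the values, so jax.jit rejects it, while a shape-preserving alternative such as `jnp.where` (which keeps every entry and only zeroes the small ones) compiles. The function names `drop_zeros` and `zero_small` are hypothetical.

```python
import jax
import jax.numpy as jnp


def drop_zeros(c):
    # Output length depends on the *values* of c: fine eagerly, not under jit.
    return c[c != 0]


def zero_small(c, atol=1e-8):
    # Output shape equals input shape, so it is known at compile time.
    return jnp.where(jnp.abs(c) > atol, c, 0.0)


c = jnp.array([0.25, 0.0, 0.0, 0.25])
print(drop_zeros(c).shape)  # (2,)

try:
    jax.jit(drop_zeros)(c)
except jax.errors.NonConcreteBooleanIndexError as err:
    print("jit failed:", type(err).__name__)

print(jax.jit(zero_small)(c).shape)  # (4,)
```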

This change doesn't happen outside of jitting, right @astralcai? I.e., pure core PL behavior stays the same? If it only occurs under jit, I think we can safely assume that any jit user is familiar with this kind of jit gotcha.

astralcai commented 1 week ago


Yes, this wouldn't affect pure PL behaviour.

obliviateandsurrender commented 1 week ago

@astralcai would users be able to use qml.simplify to get rid of the zero components in that case? For some reason, I couldn't 🤔

astralcai commented 1 week ago


qml.simplify does not remove terms with 0 coefficients.
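Because the filtering only has to happen once the values are concrete, one possible workaround is to prune the terms eagerly after the jitted call returns. A minimal sketch, assuming the coefficients and operators have already been unpacked into an array and a list (for example via `.terms()`, as in the comments above). `prune_terms` is a hypothetical helper, and plain strings stand in for PennyLane operators so the example runs without PennyLane.

```python
import numpy as np


def prune_terms(coeffs, ops, atol=1e-8):
    """Hypothetical helper: keep only (coeff, op) pairs with |coeff| > atol.

    Meant to run eagerly, outside jax.jit, where coefficient values are concrete.
    """
    values = np.asarray(coeffs)
    keep = np.flatnonzero(np.abs(values) > atol)
    return values[keep], [ops[i] for i in keep]


# Stand-ins for the jitted output: 16 coefficients and their Pauli words.
labels = [a + b for a in "IXYZ" for b in "IXYZ"]
coeffs = np.zeros(16, dtype=complex)
coeffs[[0, 3, 12, 15]] = [0.25, 0.25, -0.25, -0.25]

kept_coeffs, kept_ops = prune_terms(coeffs, labels)
print(kept_ops)  # ['II', 'IZ', 'ZI', 'ZZ']
```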

obliviateandsurrender commented 1 week ago

Don't forget to add a changelog entry!

codecov[bot] commented 1 week ago

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.67%. Comparing base (ab8b6d5) to head (fe8eff8). Report is 4 commits behind head on master.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #5878      +/-   ##
==========================================
- Coverage   99.68%   99.67%   -0.01%
==========================================
  Files         421      421
  Lines       40499    40306     -193
==========================================
- Hits        40370    40174     -196
- Misses        129      132       +3
```
