The previous PR (https://github.com/pasqal-io/horqrux/pull/27) implemented the parameter shift rule (PSR) for parameters defined in the `values` argument of expectations. However, it suffered from some limitations:
It did not support the PSR for every parameter: parameters had to be defined via `values`, and a user could not pass a parameter directly into a gate.
This PR addresses both of these points. It also adds tests checking that these cases can be jit-compiled and give the correct answers.
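For a rotation `R(θ) = exp(-i θ P / 2)` generated by a Pauli operator `P`, the PSR gives the exact derivative of an expectation from two shifted evaluations: `df/dθ = (f(θ + π/2) - f(θ - π/2)) / 2`. The following is a minimal sketch, not horqrux code. It wires this rule into JAX through `jax.custom_jvp` for a single-qubit `RX` expectation on `|0>` (all names here are hypothetical) and checks that the jit-compiled gradient matches the analytic value `-sin(θ)`.

```python
import jax
import jax.numpy as jnp

# Pauli-Z observable and |0> initial state (single qubit).
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
ZERO = jnp.array([1.0, 0.0], dtype=jnp.complex64)


def rx(theta):
    # RX(theta) = exp(-i * theta * X / 2)
    c = jnp.cos(theta / 2)
    s = jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


@jax.custom_jvp
def expectation(theta):
    # <0| RX(theta)^dagger Z RX(theta) |0>, which equals cos(theta).
    psi = rx(theta) @ ZERO
    return jnp.real(jnp.vdot(psi, Z @ psi))


@expectation.defjvp
def expectation_psr(primals, tangents):
    (theta,) = primals
    (theta_dot,) = tangents
    shift = jnp.pi / 2
    # Parameter shift rule: two shifted evaluations replace autodiff.
    deriv = (expectation(theta + shift) - expectation(theta - shift)) / 2
    return expectation(theta), deriv * theta_dot


theta = jnp.float32(0.3)
grad_fn = jax.jit(jax.grad(expectation))
print(grad_fn(theta), -jnp.sin(theta))  # the two values should agree
```

Because the tangent output is linear in `theta_dot`, JAX can also transpose the rule, so both `jax.jvp` and `jax.grad` route through the PSR.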
Some noteworthy points:
When jitting functions containing `checkify.check` calls, the function must be functionalized with `checkify.checkify`, which changes its output type to `(error, output_of_original_function)`. This is not ideal for end users, so these checks have been removed. It would be great to have such checks in the code, so an issue investigating a promising alternative has been raised (https://github.com/pasqal-io/horqrux/issues/30).
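A minimal illustration of the output-type change, using only the public `jax.experimental.checkify` API. The `normalize` function is a made-up example, not part of horqrux.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def normalize(state):
    norm = jnp.linalg.norm(state)
    # A runtime check; under jit it must be functionalized by checkify.checkify.
    checkify.check(norm > 0, "state must have non-zero norm")
    return state / norm


# Wrapping with checkify.checkify makes the check jit-compatible,
# but the wrapped function now returns (error, output) instead of output.
checked_normalize = jax.jit(checkify.checkify(normalize))

err, out = checked_normalize(jnp.array([3.0, 4.0]))
err.throw()  # does nothing here, since the check passed
print(out)   # the normalized state

err, _ = checked_normalize(jnp.zeros(2))
print(err.get())  # the failure message, as a string
```

Every caller of a checked function would have to unpack and handle this extra `error` value, which is the burden on end users described above.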
Previously, the `param` attribute of a `Parametric` gate could be of type `str | float`. This was problematic when implementing custom JVP rules, since a `float` is a valid JAX type but a `str` is not (see e.g. https://github.com/google/jax/issues/3045). Consequently, `param` has been explicitly split into `param_name: str` and `param_val: float`, so that `param_val` is always a valid JAX type.
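The sketch below illustrates why the split helps. It assumes a hypothetical `RX` pytree class (not horqrux's `Parametric`) in which `param_name` is stored as static auxiliary data. The only pytree leaf is then `param_val`, so the whole gate can be passed through `jax.custom_jvp`, `jax.grad` and `jax.jit`. If the name were stored as a leaf instead, JAX would typically reject it as an invalid type when tracing.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class RX:
    # Hypothetical parametric gate, for illustration only.
    param_name: str   # static metadata: never traced or differentiated
    param_val: float  # the only pytree leaf, always a valid JAX type
    target: int = 0

    def tree_flatten(self):
        return (self.param_val,), (self.param_name, self.target)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        param_name, target = aux_data
        return cls(param_name, children[0], target)


@jax.custom_jvp
def expectation(gate):
    # <0| RX^dagger Z RX |0> = cos(theta) for a single-qubit RX on |0>.
    return jnp.cos(gate.param_val)


@expectation.defjvp
def expectation_psr(primals, tangents):
    (gate,), (gate_dot,) = primals, tangents
    shift = jnp.pi / 2
    plus = RX(gate.param_name, gate.param_val + shift, gate.target)
    minus = RX(gate.param_name, gate.param_val - shift, gate.target)
    deriv = (expectation(plus) - expectation(minus)) / 2
    # The tangent is itself an RX pytree; only its param_val carries data.
    return expectation(gate), deriv * gate_dot.param_val


gate = RX("theta", jnp.float32(0.3))
grads = jax.jit(jax.grad(expectation))(gate)
print(grads.param_name, grads.param_val)  # "theta", approx. -sin(0.3)
```

Keeping `param_name` in the auxiliary data also makes it part of the jit cache key: changing the name triggers a retrace, while changing `param_val` does not.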
Closes #29