amazon-braket / amazon-braket-pennylane-plugin-python

A plugin for allowing Xanadu PennyLane to use Amazon Braket devices
https://amazon-braket-pennylane-plugin-python.readthedocs.io/
Apache License 2.0

Support for `SProd`, `Prod`, and `Sum` #253

Open astralcai opened 2 months ago

astralcai commented 2 months ago

Sum

Sum should be handled the same way as Hamiltonian and LinearCombination, which was partially addressed in https://github.com/amazon-braket/amazon-braket-pennylane-plugin-python/pull/252, but the same treatment should be applied to translate_result_type and translate_result in translation.py as well.

Note: Sum.ops is deprecated, so instead of measurement.obs.ops, call _, ops = measurement.obs.terms() and use ops.
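As a toy illustration of why terms() is the useful entry point for translate_result_type and translate_result, the plain-Python sketch below reads the operators only through terms() and assembles a multi-term expectation from per-term results by linearity. ToyPauli, ToySum, per_term_wires and combine_expectation are hypothetical stand-ins, not PennyLane or plugin APIs, and how translation.py actually combines results may differ.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class ToyPauli:
    """Hypothetical single-qubit Pauli on one wire (not PennyLane's class)."""
    name: str
    wire: int


@dataclass(frozen=True)
class ToySum:
    """Hypothetical flat sum that exposes its terms only through terms()."""
    coeffs: Tuple[float, ...]
    operators: Tuple[ToyPauli, ...]

    def terms(self) -> Tuple[List[float], List[ToyPauli]]:
        return list(self.coeffs), list(self.operators)


def per_term_wires(obs: ToySum) -> List[List[int]]:
    # Read the operators through terms(), mirroring `_, ops = measurement.obs.terms()`
    _, ops = obs.terms()
    return [[op.wire] for op in ops]


def combine_expectation(obs: ToySum, term_values: Sequence[float]) -> float:
    # Linearity of expectation: <sum_i c_i O_i> = sum_i c_i <O_i>
    coeffs, _ = obs.terms()
    if len(coeffs) != len(term_values):
        raise ValueError("expected one result per term")
    return sum(c * v for c, v in zip(coeffs, term_values))


obs = ToySum((0.5, 0.2), (ToyPauli("Z", 0), ToyPauli("X", 1)))
print(per_term_wires(obs))                    # [[0], [1]]
print(combine_expectation(obs, [1.0, -0.5]))  # made-up per-term values
```

The per-term values passed to combine_expectation are invented for the example; on a device they would come from one result per term.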

SProd and Prod

Since SProd and Prod could be nested, they are not guaranteed to be single-term observables. For example, an SProd could be 0.1 * (qml.Z(0) + qml.X(1)), in which case it's actually a Sum. Similarly, a Prod could be qml.Z(0) @ (qml.X(0) + qml.Y(1)).
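To make the nesting point concrete, here is a toy expansion of an observable tree into (coefficient, factors) terms, distributing scalars and products over sums. The ToyPauli, ToySum, ToySProd and ToyProd classes and the expand function are hypothetical and only model the structure; they are not PennyLane's classes.

```python
from dataclasses import dataclass
from itertools import product
from typing import List, Tuple


@dataclass(frozen=True)
class ToyPauli:
    name: str
    wire: int


@dataclass(frozen=True)
class ToySum:
    operands: tuple


@dataclass(frozen=True)
class ToySProd:
    scalar: float
    base: object


@dataclass(frozen=True)
class ToyProd:
    operands: tuple


Term = Tuple[float, Tuple[ToyPauli, ...]]


def expand(obs) -> List[Term]:
    """Distribute scalars and products over sums into (coefficient, factors) terms."""
    if isinstance(obs, ToyPauli):
        return [(1.0, (obs,))]
    if isinstance(obs, ToySum):
        return [term for op in obs.operands for term in expand(op)]
    if isinstance(obs, ToySProd):
        return [(obs.scalar * c, f) for c, f in expand(obs.base)]
    if isinstance(obs, ToyProd):
        result: List[Term] = [(1.0, ())]
        for op in obs.operands:
            # Cartesian product: every term so far times every term of the next factor
            result = [(c1 * c2, f1 + f2) for (c1, f1), (c2, f2) in product(result, expand(op))]
        return result
    raise TypeError(f"unsupported observable: {obs!r}")


# The two examples from the text
sprod = ToySProd(0.1, ToySum((ToyPauli("Z", 0), ToyPauli("X", 1))))
prod = ToyProd((ToyPauli("Z", 0), ToySum((ToyPauli("X", 0), ToyPauli("Y", 1)))))
print(len(expand(sprod)), len(expand(prod)))  # 2 2
```

Both wrappers expand to two terms, so neither can be treated as a single-term observable. The toy does not multiply Paulis that share a wire (Z(0) @ X(0) is left as a two-factor tuple), which a real simplification would do.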

This means that the same treatment for Hamiltonian, LinearCombination and Sum should extend to SProd and Prod as well, including _translate_observable, which should register Sum, SProd and Prod all under the same dispatch function as Hamiltonian, which uses H.terms().
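The registration pattern can be sketched with Python's functools.singledispatch, where stacked register calls route several multi-term types to one terms()-based handler. ToyPauli, ToyHamiltonian, ToySum and translate are hypothetical stand-ins; the plugin's real _translate_observable returns Braket observables rather than strings.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import List, Tuple


@dataclass(frozen=True)
class ToyPauli:
    name: str
    wire: int


@dataclass(frozen=True)
class ToyHamiltonian:
    coeffs: Tuple[float, ...]
    operators: Tuple[ToyPauli, ...]

    def terms(self) -> Tuple[List[float], List[ToyPauli]]:
        return list(self.coeffs), list(self.operators)


@dataclass(frozen=True)
class ToySum:
    coeffs: Tuple[float, ...]
    operators: Tuple[ToyPauli, ...]

    def terms(self) -> Tuple[List[float], List[ToyPauli]]:
        return list(self.coeffs), list(self.operators)


@singledispatch
def translate(obs) -> str:
    raise TypeError(f"no translation registered for {type(obs).__name__}")


@translate.register
def _(obs: ToyPauli):
    return f"{obs.name.lower()}({obs.wire})"


# Stacked registrations share one handler that works purely through terms()
@translate.register(ToyHamiltonian)
@translate.register(ToySum)
def _(obs):
    coeffs, ops = obs.terms()
    return " + ".join(f"{c} * {translate(op)}" for c, op in zip(coeffs, ops))


s = ToySum((0.5, 0.2), (ToyPauli("Z", 0), ToyPauli("X", 1)))
print(translate(s))  # 0.5 * z(0) + 0.2 * x(1)
```

In the same spirit, SProd and Prod stand-ins would be added as further register(...) lines on the shared handler.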

Caveat: Prod.terms() will resolve to itself if the Prod only contains one term. For example:

>>> op = qml.X(0) @ qml.Y(1)
>>> op.terms()
([1.0], [X(0) @ Y(1)])

This may result in infinite recursion in _translate_observable, so a base case should be added to return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in H.operands]) if H is a Prod with a single term.
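A toy reproduction of the recursion problem and of the proposed base case follows. ToyProd is a hypothetical class whose terms() returns the product itself, mimicking the caveat above; translations are plain strings, and the reduce(...) call mirrors the suggested fix rather than the plugin's actual code.

```python
from dataclasses import dataclass
from functools import reduce
from typing import List, Tuple


@dataclass(frozen=True)
class ToyPauli:
    name: str
    wire: int


@dataclass(frozen=True)
class ToyProd:
    """Hypothetical product of Paulis whose terms() resolves to itself."""
    operands: Tuple[ToyPauli, ...]

    def terms(self) -> Tuple[List[float], List["ToyProd"]]:
        return [1.0], [self]


def translate_naive(obs) -> str:
    if isinstance(obs, ToyPauli):
        return f"{obs.name}({obs.wire})"
    _, ops = obs.terms()
    # For a single-term ToyProd, ops == [obs], so this recursion never terminates
    return " + ".join(translate_naive(op) for op in ops)


def translate(obs) -> str:
    if isinstance(obs, ToyPauli):
        return f"{obs.name}({obs.wire})"
    coeffs, ops = obs.terms()
    if isinstance(obs, ToyProd) and len(ops) == 1:
        # Base case: translate the factors directly and join them as a tensor product
        return reduce(lambda x, y: f"{x} @ {y}", [translate(f) for f in obs.operands])
    # Genuinely multi-term observables (not modeled in this toy) would take this path
    return " + ".join(f"{c} * {translate(op)}" for c, op in zip(coeffs, ops))


op = ToyProd((ToyPauli("X", 0), ToyPauli("Y", 1)))
print(translate(op))  # X(0) @ Y(1)
```

Calling translate_naive(op) instead raises RecursionError, which is the failure mode the base case prevents.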

Note: The terms() function will unwrap any nested structures but also simplify the observable. For example:

>>> op = qml.X(0) @ qml.I(1)
>>> op.terms()
([1.0], [X(0)])

This will create a mismatch between the number of targets in the translated observable and in the original observable. We plan to address this in PennyLane by having terms() recursively unwrap the observable without doing any simplification. For now, however, do not use circuit.measurements directly in _pl_to_braket_circuit; instead, do something like:

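measurements = []
for mp in circuit.measurements:
    obs = mp.obs
    # Simplify potentially nested, multi-term observables once, up front
    if isinstance(obs, (Hamiltonian, LinearCombination, Sum, SProd, Prod)):
        obs = obs.simplify()
        # Rebuild the same measurement type around the simplified observable
        mp = type(mp)(obs)
    measurements.append(mp)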

Then use measurements instead of circuit.measurements from this point on. The list of simplified measurements should also be passed into _apply_gradient_result_type and used there.
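A toy model of the target mismatch and of the fix: a terms()-based translation sees the simplified observable, so comparing its wires against the unsimplified measurement fails, while comparing against the pre-simplified list succeeds. ToyPauli, ToyProd, ToyExpval, translated_wires and preprocess are hypothetical, and the toy simplify() only drops identity factors, a small subset of what PennyLane's simplify() does.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class ToyPauli:
    name: str
    wire: int


@dataclass(frozen=True)
class ToyProd:
    operands: Tuple[ToyPauli, ...]

    @property
    def wires(self) -> List[int]:
        return [op.wire for op in self.operands]

    def simplify(self) -> "ToyProd":
        # Toy simplification: drop identity factors (keep one if nothing else is left)
        kept = tuple(op for op in self.operands if op.name != "I")
        return ToyProd(kept or self.operands[:1])


@dataclass(frozen=True)
class ToyExpval:
    obs: ToyProd


def translated_wires(mp: ToyExpval) -> List[int]:
    # Stand-in for a terms()-based translation, which simplifies as a side effect
    return mp.obs.simplify().wires


def preprocess(measurements: List[ToyExpval]) -> List[ToyExpval]:
    # Simplify once; the same list would go to the circuit builder and the gradient code
    return [type(mp)(mp.obs.simplify()) for mp in measurements]


original = [ToyExpval(ToyProd((ToyPauli("X", 0), ToyPauli("I", 1))))]
print([mp.obs.wires == translated_wires(mp) for mp in original])      # [False]
measurements = preprocess(original)
print([mp.obs.wires == translated_wires(mp) for mp in measurements])  # [True]
```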

Device

Since SProd, Prod, and Sum can all be nested, multi-term observables, they should be removed from the list of supported observables and added back only when no shots are present:

@property
def observables(self) -> frozenset[str]:
    base_observables = frozenset(super().observables - {"Prod", "SProd", "Sum"})
    # Amazon Braket only supports coefficients and multiple terms when shots==0
    if not self.shots:
        return base_observables.union({"Hamiltonian", "LinearCombination", "Prod", "SProd", "Sum"})
    return base_observables
math411 commented 2 months ago

Hi @astralcai, thank you for raising this. I shall start looking into a fix.

speller26 commented 1 month ago

This means that the same treatment for Hamiltonian, LinearCombination and Sum should extend to SProd and Prod as well, including _translate_observable, which should register Sum, SProd and Prod all under the same dispatch function as Hamiltonian, which uses H.terms().

The current _translate_observable implementations for Sum, SProd and Prod recursively call _translate_observable on their operands:

@_translate_observable.register
def _(t: qml.operation.Tensor):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.obs])

@_translate_observable.register
def _(t: qml.ops.Prod):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])

@_translate_observable.register
def _(t: qml.ops.SProd):
    return t.scalar * _translate_observable(t.base)

Shouldn't that take care of the nesting problem?

speller26 commented 1 month ago
measurements = []
for mp in circuit.measurements:
    obs = mp.obs
    if isinstance(obs, (Hamiltonian, LinearCombination, Sum, SProd, Prod)):
        obs = obs.simplify()
        mp = type(mp)(obs)
    measurements.append(mp)

I'm noticing that simplify alters the order of operands (at least in Prod); is this intentional?

astralcai commented 1 month ago

This means that the same treatment for Hamiltonian, LinearCombination and Sum should extend to SProd and Prod as well, including _translate_observable, which should register Sum, SProd and Prod all under the same dispatch function as Hamiltonian, which uses H.terms().

The current _translate_observable implementations for Sum, SProd and Prod recursively call _translate_observable on their operands:

@_translate_observable.register
def _(t: qml.operation.Tensor):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.obs])

@_translate_observable.register
def _(t: qml.ops.Prod):
    return reduce(lambda x, y: x @ y, [_translate_observable(factor) for factor in t.operands])

@_translate_observable.register
def _(t: qml.ops.SProd):
    return t.scalar * _translate_observable(t.base)

Shouldn't that take care of the nesting problem?

It should, but as I recall, it didn't. I looked into it some time ago and couldn't make it work, which is why I suggested using the same approach for all potential multi-term observables. You can give it a try. I don't remember exactly what the issue was, but I believe it had to do with the Braket backend being unable to parse scalar products.

/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/interpreter.py:545: in _
    parsed = self.context.parse_pragma(node.command)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/program_context.py:455: in parse_pragma
    return parse_braket_pragma(pragma_body, self.qubit_mapping)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/braket_pragmas.py:216: in parse_braket_pragma
    visited = BraketPragmaNodeVisitor(qubit_table).visit(tree)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:34: in visit
    return tree.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:861: in accept
    return visitor.visitBraketPragma(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:14: in visitBraketPragma
    return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
    childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1226: in accept
    return visitor.visitBraketResultPragma(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:39: in visitBraketResultPragma
    return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
    childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1290: in accept
    return visitor.visitResultType(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParserVisitor.py:44: in visitResultType
    return self.visitChildren(ctx)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/antlr4/tree/Tree.py:44: in visitChildren
    childResult = c.accept(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/generated/BraketPragmasParser.py:1867: in accept
    return visitor.visitObservableResultType(self)
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/braket/default_simulator/openqasm/parser/braket_pragmas.py:98: in visitObservableResultType
    observables, targets = self.visit(ctx.observable())
E   TypeError: cannot unpack non-iterable NoneType object
----------------------------- Captured stderr call -----------------------------
line 1:26 mismatched input '0.1' expecting {'x', 'y', 'z', 'i', 'h', 'hermitian'}

This occurred when trying to parse the scalar product of an observable. See this run: https://github.com/PennyLaneAI/plugin-test-matrix/actions/runs/9018042395/job/24777766316
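The parser message shows the Braket result pragma expecting an observable name (x, y, z, i, h, hermitian) where the coefficient 0.1 appeared. A sketch of the idea behind routing everything through terms(): keep coefficients on the classical side and serialize only coefficient-free observables, one per term. The exact pragma text, including the q[...] qubit references, is an assumption about Braket's OpenQASM serialization, and result_pragma and split_terms are hypothetical helpers, not plugin functions.

```python
from typing import List, Sequence, Tuple

Factor = Tuple[str, int]           # (Pauli name, wire), e.g. ("Z", 0)
Term = Tuple[float, List[Factor]]  # (coefficient, tensor product of factors)


def result_pragma(factors: Sequence[Factor]) -> str:
    """Serialize one coefficient-free observable (assumes a qubit register named q)."""
    observable = " @ ".join(f"{name.lower()}(q[{wire}])" for name, wire in factors)
    return f"#pragma braket result expectation {observable}"


def split_terms(terms: Sequence[Term]) -> Tuple[List[float], List[str]]:
    """Keep coefficients classical and emit one scalar-free pragma per term."""
    coeffs = [coeff for coeff, _ in terms]
    pragmas = [result_pragma(factors) for _, factors in terms]
    return coeffs, pragmas


# 0.1 * (Z(0) + X(1)), already expanded into terms
coeffs, pragmas = split_terms([(0.1, [("Z", 0)]), (0.1, [("X", 1)])])
for pragma in pragmas:
    print(pragma)
```

The per-term results would then be weighted by coeffs after execution, so no scalar ever reaches the pragma parser.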

astralcai commented 1 month ago
measurements = []
for mp in circuit.measurements:
    obs = mp.obs
    if isinstance(obs, (Hamiltonian, LinearCombination, Sum, SProd, Prod)):
        obs = obs.simplify()
        mp = type(mp)(obs)
    measurements.append(mp)

I'm noticing that simplify alters the order of operands (at least in Prod); is this intentional?

Simplify does not preserve the original order of operands.
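For Pauli factors acting on different wires, reordering does not change the operator, since such factors commute; what matters is that the translated factor order and the target list come from the same object. A toy check, modeling a product as (name, wire) pairs with hypothetical canonical and targets helpers:

```python
from typing import Dict, List, Tuple

Factor = Tuple[str, int]  # (Pauli name, wire)


def canonical(factors: List[Factor]) -> Dict[int, str]:
    """Order-independent form of a Pauli product whose factors act on distinct wires."""
    wires = [wire for _, wire in factors]
    if len(set(wires)) != len(wires):
        raise ValueError("factors must act on distinct wires")
    return {wire: name for name, wire in factors}


def targets(factors: List[Factor]) -> List[int]:
    # Derive targets from the same (possibly reordered) factor list that gets translated
    return [wire for _, wire in factors]


original = [("Y", 1), ("Z", 0)]
reordered = [("Z", 0), ("Y", 1)]  # e.g. what a reordering simplify() might return
print(canonical(original) == canonical(reordered))  # True
print(targets(reordered))  # [0, 1]
```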