PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] `ConcretizationTypeError` when unnecessary `work_wires` is specified. #4481

Open erick-xanadu opened 1 year ago

erick-xanadu commented 1 year ago

Expected behavior

Hi,

I was testing another issue that required `work_wires` when I found that some quantum programs cannot be compiled with `jax.jit` when `work_wires` is passed, even if it is not required. Note that the submitted example compiles with `jax.jit` successfully once the unneeded keyword argument `work_wires=[7]` is removed.

Actual behavior

ConcretizationTypeError is raised.

Additional information

No response

Source code

import pennylane as qml
import jax

@jax.jit
@qml.qnode(qml.device("lightning.qubit", wires=8))
def circuit(x : int):
    op = qml.Identity(wires=[0])
    op2 = qml.ctrl(op, control=[x], work_wires=[7])
    qml.matrix(op2)
    return qml.state()

print(circuit(1))

Tracebacks

Traceback (most recent call last):
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 266, in fn
    return self.tape_fn(obj.expand(), *args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 301, in tape_fn
    return self._tape_fn(obj, *args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/functions/matrix.py", line 134, in _matrix
    raise qml.operation.MatrixUndefinedError
pennylane.operation.MatrixUndefinedError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/erick.ochoalopez/Code/catalyst-latest/test.py", line 12, in <module>
    print(circuit(1))
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/qnode.py", line 975, in __call__
    self.construct(args, kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/qnode.py", line 872, in construct
    self._tape = make_qscript(self.func, shots)(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/tape/qscript.py", line 1492, in wrapper
    result = fn(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/test.py", line 9, in circuit
    qml.matrix(op2)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 213, in __call__
    return self._create_wrapper(obj, *targs, **tkwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 412, in _create_wrapper
    wrapper = self.fn(obj, *targs, **tkwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 275, in fn
    raise e1 from e
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 260, in fn
    return self._fn(obj, *args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/functions/matrix.py", line 127, in matrix
    return op.matrix(wire_order=wire_order)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/op_math/controlled.py", line 433, in matrix
    return qml.math.expand_matrix(
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/math/matrix_manipulation.py", line 134, in expand_matrix
    wire_indices = [wire_order.index(wire) for wire in wires]
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/math/matrix_manipulation.py", line 134, in <listcomp>
    wire_indices = [wire_order.index(wire) for wire in wires]
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/core.py", line 667, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages/jax/_src/core.py", line 1370, in error
    raise ConcretizationTypeError(arg, fname_context)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[].
The problem arose with the `bool` function. 
The error occurred while tracing the function circuit at /home/erick.ochoalopez/Code/catalyst-latest/test.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument x.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/erick.ochoalopez/Code/catalyst-latest/test.py", line 12, in <module>
    print(circuit(1))
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/qnode.py", line 975, in __call__
    self.construct(args, kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/qnode.py", line 872, in construct
    self._tape = make_qscript(self.func, shots)(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/tape/qscript.py", line 1492, in wrapper
    result = fn(*args, **kwargs)
  File "/home/erick.ochoalopez/Code/catalyst-latest/test.py", line 9, in circuit
    qml.matrix(op2)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 213, in __call__
    return self._create_wrapper(obj, *targs, **tkwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 412, in _create_wrapper
    wrapper = self.fn(obj, *targs, **tkwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 275, in fn
    raise e1 from e
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/transforms/op_transforms.py", line 260, in fn
    return self._fn(obj, *args, **kwargs)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/functions/matrix.py", line 127, in matrix
    return op.matrix(wire_order=wire_order)
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/ops/op_math/controlled.py", line 433, in matrix
    return qml.math.expand_matrix(
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/math/matrix_manipulation.py", line 134, in expand_matrix
    wire_indices = [wire_order.index(wire) for wire in wires]
  File "/home/erick.ochoalopez/Code/pennylane/pennylane/math/matrix_manipulation.py", line 134, in <listcomp>
    wire_indices = [wire_order.index(wire) for wire in wires]
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[].
The problem arose with the `bool` function. 
The error occurred while tracing the function circuit at /home/erick.ochoalopez/Code/catalyst-latest/test.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument x.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
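The innermost frames point at `wire_order.index(wire)` inside `qml.math.expand_matrix`. When a wire label is a JAX tracer, `list.index` compares it with each concrete label, and Python then calls `bool()` on the traced comparison result, which is what raises. Below is a minimal pure-JAX sketch of that pattern. `WIRE_ORDER` and `find_wire_index` are hypothetical stand-ins, not PennyLane code.

```python
import jax

# Hypothetical stand-in for a device wire order (not taken from PennyLane).
WIRE_ORDER = [0, 1, 2, 3, 4, 5, 6, 7]


def find_wire_index(wire):
    # Same pattern as `wire_order.index(wire)` in the traceback:
    # list.index evaluates `label == wire` and calls bool() on the result.
    return WIRE_ORDER.index(wire)


def traced_lookup_fails(wire):
    """Return True if jitting with a traced `wire` raises the error."""
    try:
        jax.jit(find_wire_index)(wire)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


# With a traced wire label, bool(tracer == label) cannot be evaluated.
print(traced_lookup_fails(1))  # True

# Marking the argument static makes it a concrete Python int at trace time.
print(jax.jit(find_wire_index, static_argnums=0)(1))  # 1
```

Newer JAX releases raise `TracerBoolConversionError` here. It is a subclass of `ConcretizationTypeError`, so the `except` clause catches it in either case.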

System information

>>> qml.about()
Name: PennyLane
Version: 0.32.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/erick.ochoalopez/Code/catalyst-latest/env/lib/python3.10/site-packages
Editable project location: /home/erick.ochoalopez/Code/pennylane
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: pennylane-catalyst, PennyLane-Lightning

Existing GitHub issues

albi3ro commented 1 year ago

@erick-xanadu Shouldn't we be treating wire labels as static metadata / compile-time constants?
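One way to read "static metadata" in plain JAX is through `jax.jit`'s `static_argnums`. A static argument is a concrete Python value during tracing, so indexing and Python control flow on it work. The cost is that every new value triggers a retrace and recompile. The toy sketch below assumes a wire label is just an integer index into an array. `flip_wire` and `trace_count` are hypothetical illustrations, not PennyLane APIs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often JAX (re)traces the function


def flip_wire(state, wire):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    # `wire` is a plain Python int here, so ordinary indexing works.
    return state.at[wire].set(1.0 - state[wire])


# Treat `wire` as compile-time metadata: each new value triggers a retrace.
flip_wire_jit = jax.jit(flip_wire, static_argnums=1)

state = jnp.zeros(8)
flip_wire_jit(state, 0)
flip_wire_jit(state, 0)  # cache hit: same static value, no retrace
flip_wire_jit(state, 7)  # new static value: retraced and recompiled
print(trace_count)  # 2
```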

trbromley commented 1 year ago

@erick-xanadu Shouldn't we be treating wire labels as static metadata / compile-time constants?

@albi3ro what would the motivation be for this? What shapes your intuition about dynamic variables vs. static constants?

albi3ro commented 1 year ago

Previously, we've treated "things that are potentially trainable", such as any TensorLike, as the dynamic variables. That assumption is baked into both how we write things and how we test things.

When we allow a variable to be abstract, we strongly limit what we can do with it. For example, we can no longer use it in Python control flow.

I also don't think we have a single test for abstract wires.

I would be open to allowing wires to be dynamic, but we would need time to adjust the assumptions in our code, add tests, and work through all the problems that will inevitably come up.
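The control-flow limitation can be seen in a small pure-JAX sketch. A Python `if` on a traced value fails under `jax.jit`, while `jax.lax.cond` keeps the predicate abstract and stages both branches. The function names are illustrative only. The sketch does not show how PennyLane would handle dynamic wires.

```python
import jax


def python_if(x):
    # A Python `if` needs a concrete bool, which a tracer cannot provide.
    if x > 0:
        return x
    return -x


def staged_if(x):
    # lax.cond keeps the predicate abstract and stages both branches.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


try:
    jax.jit(python_if)(-3)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

print(python_if_failed)        # True
print(jax.jit(staged_if)(-3))  # 3
```

Dynamic wire labels would force code paths like `expand_matrix` into the second style. That is one reason the change would need the code audit and new tests described above.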

josh146 commented 2 weeks ago

Hey all, just revisiting this issue now (with newer context we might have from plxpr work).

Will plxpr be treating wires as dynamic or static?

isaacdevlugt commented 1 week ago

This does work with qjit btw :)