sergei-mironov closed this issue 1 year ago.
Yes. Scalars and arrays take different conversion paths in `CompiledFunction`. I believe this patch should solve the issue:
```diff
diff --git a/frontend/catalyst/compilation_pipelines.py b/frontend/catalyst/compilation_pipelines.py
index 5e0bcdb..060b2d6 100644
--- a/frontend/catalyst/compilation_pipelines.py
+++ b/frontend/catalyst/compilation_pipelines.py
@@ -31,7 +31,7 @@ from jax.interpreters.mlir import ir
 import pennylane as qml
-from mlir_quantum.runtime import get_ranked_memref_descriptor, ranked_memref_to_numpy
+from mlir_quantum.runtime import get_ranked_memref_descriptor, ranked_memref_to_numpy, to_numpy
 import catalyst.jax_tracer as tracer
 from catalyst import compiler
@@ -261,7 +261,7 @@ class CompiledFunction:
             a numpy array with the contents of the ranked memref descriptor
         """
         assert not hasattr(ranked_memref, "shape")
-        return np.array(ranked_memref.aligned.contents)
+        return to_numpy(np.array(ranked_memref.aligned.contents))
     @staticmethod
     def ranked_memref_to_numpy(memref_desc):
```
That said, would it be a good idea to submit a patch to MLIR that allows casting scalar memrefs to NumPy scalar arrays, so that we can avoid this special handling?
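To show why the scalar path needs the extra conversion step, here is a minimal, self-contained sketch. It does not import `mlir_quantum`. Instead it defines a hypothetical `C128` ctypes structure, which stands in for the ctypes representation of a complex128 scalar. It also defines a hypothetical helper, `scalar_struct_to_numpy`, which roughly mirrors what we assume `to_numpy` does for complex values. Calling `np.array(...)` on the struct yields a 0-d *structured* array with `real`/`imag` fields rather than a complex number. Viewing those 16 bytes as `complex128` fixes that without copying.

```python
import ctypes

import numpy as np


class C128(ctypes.Structure):
    """Hypothetical stand-in for the ctypes layout of a complex128 scalar."""

    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


def scalar_struct_to_numpy(value: C128) -> np.ndarray:
    """Hypothetical helper: convert a C128 struct to a 0-d complex128 array."""
    # np.array() on a ctypes Structure produces a 0-d structured array
    # with fields ("real", "imag"), not a complex scalar.
    raw = np.array(value)
    if raw.dtype.names == ("real", "imag"):
        # Two contiguous float64 fields occupy the same 16 bytes as one
        # complex128, so the buffer can be reinterpreted in place.
        return raw.view(np.complex128)
    # Non-complex scalars (e.g. float64) already have a usable dtype.
    return raw


value = C128(1.0, 2.0)
print(np.array(value).dtype)          # structured dtype with real/imag fields
print(scalar_struct_to_numpy(value))  # 0-d complex128 array holding 1+2j
```

The array path does not hit this problem in the same way, because `ranked_memref_to_numpy` already performs a dtype-aware conversion. Only the scalar branch returned the raw structured array.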
The following simple program doesn't work if the type is `jnp.complex128`, but works with `float64`. Should we support complex scalars as well? The error is