PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[Frontend] Support scalars of type complex? #76

Closed: sergei-mironov closed this issue 1 year ago

sergei-mironov commented 1 year ago

The following simple program fails when the dtype is jnp.complex128 but works with float64. Should we support complex scalars as well?

from catalyst import qjit
from jax.numpy import array, complex128, float64

@qjit
def main():
    return array(0, dtype=complex128)  # fails; dtype=float64 works

main()

The error is:

Traceback (most recent call last):
  File "/workspace/src/synthesis/issue1.py", line 10, in <module>
    main()
  File "/workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py", line 598, in __call__
    return self.compiled_function(*args, **kwargs)
  File "/workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py", line 444, in __call__
    result = CompiledFunction._exec(
  File "/workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py", line 332, in _exec
    retval = CompiledFunction.return_value_ptr_to_numpy(result) if result else None
  File "/workspace/modules/catalyst/frontend/catalyst/compilation_pipelines.py", line 297, in return_value_ptr_to_numpy
    jax_array = jax.numpy.asarray(numpy_array)
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 2036, in asarray
    return array(a, dtype=dtype, copy=False, order=order)  # type: ignore
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1982, in array
    dtype = dtypes._lattice_result_type(*leaves)[0]
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/dtypes.py", line 469, in _lattice_result_type
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/dtypes.py", line 469, in <genexpr>
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/dtypes.py", line 312, in _dtype_and_weaktype
    return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
  File "/home/grwlf/.local/lib/python3.10/site-packages/jax/_src/dtypes.py", line 464, in dtype
    raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
TypeError: Value '(0., 0.)' with dtype [('real', '<f8'), ('imag', '<f8')] is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
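The dtype in this message points at the cause: the scalar result is read back as a ctypes structure with real and imag double fields, and np.array turns such a structure into a 0-d NumPy structured array rather than a complex128 one. Structured dtypes are not numeric, so JAX rejects them. A minimal reproduction using only ctypes and NumPy, assuming a struct layout that matches the dtype in the error (Complex128Struct is a hypothetical stand-in for the runtime's own type):

```python
import ctypes

import numpy as np


# Hypothetical stand-in for the runtime's complex128 ctype; the field layout
# mirrors the dtype printed in the error: [('real', '<f8'), ('imag', '<f8')].
class Complex128Struct(ctypes.Structure):
    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


# A pointer to the returned scalar, similar to a memref's `aligned` field.
ptr = ctypes.pointer(Complex128Struct(0.0, 0.0))

# Converting the pointee directly yields a 0-d *structured* array.
raw = np.array(ptr.contents)
print(raw.dtype)  # a structured dtype with 'real' and 'imag' fields

# Structured dtypes are not numeric, which is what JAX refuses to accept.
print(np.issubdtype(raw.dtype, np.number))  # False
```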

erick-xanadu commented 1 year ago

Yes. Scalars and arrays take different code paths when the returned memref is converted to NumPy, and only the array path handles complex types. I believe this patch should solve the issue:

diff --git a/frontend/catalyst/compilation_pipelines.py b/frontend/catalyst/compilation_pipelines.py
index 5e0bcdb..060b2d6 100644
--- a/frontend/catalyst/compilation_pipelines.py
+++ b/frontend/catalyst/compilation_pipelines.py
@@ -31,7 +31,7 @@ from jax.interpreters.mlir import ir

 import pennylane as qml

-from mlir_quantum.runtime import get_ranked_memref_descriptor, ranked_memref_to_numpy
+from mlir_quantum.runtime import get_ranked_memref_descriptor, ranked_memref_to_numpy, to_numpy

 import catalyst.jax_tracer as tracer
 from catalyst import compiler
@@ -261,7 +261,7 @@ class CompiledFunction:
             a numpy array with the contents of the ranked memref descriptor
         """
         assert not hasattr(ranked_memref, "shape")
-        return np.array(ranked_memref.aligned.contents)
+        return to_numpy(np.array(ranked_memref.aligned.contents))

     @staticmethod
     def ranked_memref_to_numpy(memref_desc):

Alternatively, would it be worth submitting a patch to MLIR that allows converting scalar memrefs to NumPy scalar arrays directly, so this special handling can be avoided?
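For reference, the patch works because a structured array with two float64 fields has the same 16-byte itemsize as complex128, so the same bytes can be reinterpreted with ndarray.view; MLIR's to_numpy appears to do essentially this for its complex ctypes. A standalone sketch of the patched scalar path, assuming the real/imag struct layout from the error message (scalar_ptr_to_numpy, Complex128Struct and Complex64Struct are hypothetical names, not Catalyst or MLIR API):

```python
import ctypes

import numpy as np


# Hypothetical ctypes layouts for complex scalars: real part first, then imaginary.
class Complex128Struct(ctypes.Structure):
    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


class Complex64Struct(ctypes.Structure):
    _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]


# Map the struct's total size in bytes to the complex dtype with the same layout.
_COMPLEX_BY_ITEMSIZE = {16: np.complex128, 8: np.complex64}


def scalar_ptr_to_numpy(ptr):
    """Convert a ctypes pointer to a scalar into a 0-d NumPy array.

    np.array turns a complex struct into a structured array, so such results
    are reinterpreted (same bytes, same itemsize) as a numeric complex dtype.
    Plain numeric ctypes such as c_double pass through unchanged.
    """
    arr = np.array(ptr.contents)
    itemsize = arr.dtype.itemsize
    if arr.dtype.names == ("real", "imag") and itemsize in _COMPLEX_BY_ITEMSIZE:
        return arr.view(_COMPLEX_BY_ITEMSIZE[itemsize])
    return arr


z = scalar_ptr_to_numpy(ctypes.pointer(Complex128Struct(1.5, -2.0)))
print(z.dtype, complex(z))  # complex128 (1.5-2j)
```

With this conversion, the jax.numpy.asarray call seen in the traceback should receive an ordinary complex128 array. Using view rather than assembling a new complex value keeps the step a pure reinterpretation of the returned bytes.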