PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Find a better way to link blas routines from jaxlib #753

Open paul0403 opened 2 months ago

paul0403 commented 2 months ago

Currently, Catalyst links against the BLAS library shipped with SciPy rather than jaxlib, because jaxlib does not ship its BLAS library as a .so shared object file.

As a result, when numerical functions from jaxlib are called (for example, those in jax.scipy.linalg), most of them reference undefined symbols. This is currently being fixed by manually adding the missing routines to frontend/catalyst/utils/libcustom_calls.cpp on a case-by-case basis, which is not efficient. For example,

import jax.numpy as jnp
import jax.scipy.linalg
import pennylane as qml

@qml.qjit
def func(x):
    res = jax.scipy.linalg.expm(x)
    return res

y = jnp.array([[1, 0], [0, 1]])
x = func(y)
print(x)
# [[2.71828183 0.        ]
#  [0.         2.71828183]]

but

@qml.qjit
def func(x):
    res = jax.scipy.linalg.sqrtm(x)
    return res

y = jnp.array([[1, 0], [0, 1]])
x = func(y)

Traceback (most recent call last):
  File "/home/paul.wang/small_playgrounds_dump/expmfix.py", line 56, in <module>
    x = func(y)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 110, in __call__
    requires_promotion = self.jit_compile(args)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 171, in jit_compile
    self.compiled_function, self.qir = self.compile()
  File "/home/paul.wang/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 278, in compile
    compiled_fn = CompiledFunction(shared_object, func_name, restype, self.compile_options)
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 132, in __init__
    self.shared_object = SharedObjectManager(shared_object_file, func_name)
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 61, in __init__
    self.open()
  File "/home/paul.wang/catalyst/frontend/catalyst/compiled_functions.py", line 65, in open
    self.shared_object = ctypes.CDLL(self.shared_object_file)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /tmp/funcwym85xsl/func.so: undefined symbol: lapack_zgees

This is because the routines required by expm were added in #752, while those required by sqrtm (such as lapack_zgees, the symbol in the error above) are still missing.

See details in #752
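One way to make the case-by-case work less manual would be to list the custom-call targets that JAX-generated code refers to and check which of them a given shared library can actually resolve. The sketch below is a hedged illustration, not part of Catalyst: `custom_call_targets` and `unresolved_symbols` are hypothetical helper names, and the two IR spellings it matches (`custom_call @name` and `call_target_name = "name"`) are assumptions about how the lowered module is printed by the JAX version in use. `ctypes.CDLL(None)` opens the running process itself, which works on Linux.

```python
import ctypes
import re

# Assumed IR spellings of a custom call: the pretty form
# `stablehlo.custom_call @lapack_zgees(...)` and the generic form
# `{call_target_name = "lapack_zgees"}`.
_TARGET_PATTERNS = (
    re.compile(r"custom_call\s*@([\w.$]+)"),
    re.compile(r'call_target_name\s*=\s*"([^"]+)"'),
)


def custom_call_targets(ir_text):
    """Return the sorted, de-duplicated custom-call target names in an IR dump."""
    names = set()
    for pattern in _TARGET_PATTERNS:
        names.update(pattern.findall(ir_text))
    return sorted(names)


def unresolved_symbols(names, library_path=None):
    """Return the names that cannot be looked up in a shared library.

    library_path=None opens the current process (dlopen(NULL) on Linux).
    ctypes raises AttributeError for a missing symbol, which hasattr catches.
    """
    lib = ctypes.CDLL(library_path)
    return [name for name in names if not hasattr(lib, name)]
```

For example, `custom_call_targets(jax.jit(f).lower(x).as_text())` would list the targets needed by a plain (non-qjit) JAX function `f`, assuming Catalyst receives the same target names. Passing that list, together with the path of the library meant to provide the wrappers, to `unresolved_symbols` would then report the ones still missing before a compiled .so fails to load with an OSError.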

mlxd commented 2 months ago

Since this is also under discussion for Lightning, https://pypi.org/project/scipy-openblas64/ may be a good candidate to pull in and use throughout, as it is already the default used within SciPy (despite the binary being included and hashed in the wheel). Happy to discuss this next week.

dime10 commented 2 months ago


Just a note that our particular issue is due to JAX generating code that expects some small custom wrappers around LAPACK/BLAS functions. Since we receive JAX-generated code for linalg functions, we also have to use these custom wrappers rather than interfacing with the LAPACK functions directly.

This issue is about avoiding the need to include (slightly modified versions of) these wrappers in our package.
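To make the wrapper point concrete, the sketch below models such a wrapper in Python with NumPy. It is an illustration only: the real wrappers are compiled C/C++ functions that receive raw pointers (per XLA's documentation, legacy CPU custom calls have the signature `void(void* out, const void** in)`), `gesv_like_wrapper` is a hypothetical name, and its operand layout (a leading matrix order followed by flat data buffers) is a simplification, not jaxlib's actual layout. `numpy.linalg.solve`, which NumPy documents as using LAPACK's `_gesv`, stands in for the underlying routine.

```python
import numpy as np


def gesv_like_wrapper(out, ins):
    """Hypothetical custom-call wrapper (simplified layout, not jaxlib's).

    ins = [n, a_flat, b_flat]  n: matrix order, a_flat: n*n values, b_flat: n values
    out = [x_flat]             caller-allocated buffer of n values
    """
    # Scalar operands describe the shapes of the buffers that follow.
    n = int(ins[0])
    # Reinterpret the flat buffers with the shapes the scalars describe.
    a = np.asarray(ins[1], dtype=np.float64).reshape(n, n)
    b = np.asarray(ins[2], dtype=np.float64).reshape(n)
    # Delegate to the underlying solver (NumPy calls LAPACK's gesv).
    x = np.linalg.solve(a, b)
    # Fill the caller's output buffer in place instead of returning a new
    # array, mirroring how custom calls write into memory owned by the caller.
    out[0][:] = x
```

The unpacking step is the part that matters: the generated code only knows the target name and the flat operand list, so whatever library provides that symbol must follow the same layout, which plain LAPACK entry points do not.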