jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Improve support for using JAX's custom LAPACK calls externally #23172

Open joeycarter opened 1 month ago

joeycarter commented 1 month ago

Hi there,

This request is a follow-up to the discussion here: https://github.com/google/jax/discussions/18065.

What we're trying to accomplish

Suppose we have our own JIT compiler, called qjit. We would like to be able to do the following, for example:

import jax
import jax.numpy as jnp
import qjit

@qjit
def matrix_sqrt(A):
    return jax.scipy.linalg.sqrtm(A)

X = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])
matrix_sqrt(X)

and similarly for the other functions in jax.scipy.linalg. However, when we do so, we get undefined symbol errors, in this case:

OSError: /tmp/matrix_sqrtonhryf12/matrix_sqrt.so: undefined symbol: lapack_zgees

Current workaround

What we currently do to get around this is what was suggested in https://github.com/google/jax/discussions/18065: we manually compile and link in the required custom JAX LAPACK modules under jaxlib/cpu/ so that the required symbols, such as lapack_zgees, are defined.

However, this is cumbersome and difficult to maintain (suppose these modules change in a future JAX release).

What else we've tried

We noticed that the jaxlib package comes shipped with the shared object file jaxlib/cpu/_lapack.so, which contains the symbols for the kernel functions that the custom JAX wrappers use. For instance, using the nm tool, we can find the corresponding kernel function for lapack_zgees, ComplexGees<std::complex<double>>::Kernel:

$ nm -C _lapack.so | grep "ComplexGees<std::complex<double> >::Kernel"
000000000000d840 t jax::ComplexGees<std::complex<double> >::Kernel(void*, void**, XlaCustomCallStatus_*)

or, in its mangled form:

$ nm _lapack.so | grep "000000000000d840"
000000000000d840 t _ZN3jax11ComplexGeesISt7complexIdEE6KernelEPvPS4_P20XlaCustomCallStatus_

We tried loading this symbol using the dynamic linking loader as follows (simplified for brevity):

struct XlaCustomCallStatus_ {};

void* handle = dlopen(".../jaxlib/cpu/_lapack.so", RTLD_LAZY);

std::string symbol = "_ZN3jax11ComplexGeesISt7complexIdEE6KernelEPvPS4_P20XlaCustomCallStatus_";
typedef void (*Kernel_t)(void* out_tuple, void** data, XlaCustomCallStatus_*);
Kernel_t zgees_kernel = (Kernel_t) dlsym(handle, symbol.c_str());  // returns NULL if the symbol is not found

However, dlsym fails to find the symbol. We believe this is because this function is not exported, as denoted by the lowercase t in the nm output (where exported functions are conventionally denoted by uppercase letters, e.g. T).

So, we believe that we've hit a dead-end with this approach.
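To illustrate the lowercase/uppercase convention mentioned above, here is a minimal sketch that classifies lines of nm output. The helper names parse_nm_line and looks_local are hypothetical and exist only for this example. It assumes GNU nm's documented convention that lowercase type letters mark local symbols, with u, v and w as the exceptions. Note that an uppercase letter is only a necessary condition here: dlsym looks symbols up in the dynamic symbol table, which nm -D lists.

```python
# Sample line taken from the nm output quoted above.
SAMPLE = ("000000000000d840 t "
          "_ZN3jax11ComplexGeesISt7complexIdEE6KernelEPvPS4_P20XlaCustomCallStatus_")


def parse_nm_line(line):
    """Split one line of `nm` output into (address, type_letter, name).

    Hypothetical helper. Lines without an address column (for example
    undefined symbols, type 'U') are returned with address set to None.
    """
    first, rest = line.strip().split(None, 1)
    if len(first) == 1:  # no address column, the first token is the type letter
        return None, first, rest
    kind, name = rest.split(None, 1)
    return first, kind, name


def looks_local(kind):
    """Return True if the nm type letter marks a local (non-exported) symbol.

    Assumption: GNU nm convention, where lowercase means local, except for
    'u' (unique global) and 'v'/'w' (weak), which are special global symbols.
    """
    return kind.islower() and kind not in ("u", "v", "w")


address, kind, name = parse_nm_line(SAMPLE)
print(address, kind, "local" if looks_local(kind) else "global")
```

For the sample line the type letter is t, so the symbol is classified as local, which is consistent with dlsym not finding it.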

Possible solutions

Given the fact that these symbols are already shipped with jaxlib in jaxlib/cpu/_lapack.so, would it be possible to make these functions globally available in a future JAX release to make it possible to dynamically load them using dlopen and dlsym? If that is not possible, is there another approach that is more amenable to using these custom LAPACK calls than manually building and linking them in ourselves?

dfm commented 1 month ago

Thanks for this request! This is something we'd love to support. We don't have a specific timeline, but I wanted to just confirm here that this feature request is acknowledged.

dfm commented 1 month ago

Adding one clarification: jaxlib doesn't come with its own LAPACK library. It links to the one from scipy. The code used to populate our API with those symbols is here:

https://github.com/google/jax/blob/b56ed8eeddc5794f3981832a38b6bcc195eb20f8/jaxlib/cpu/lapack.cc#L40-L152

It's probably worth noting that as part of this conversation!

joeycarter commented 1 month ago

Hi @dfm,

Could you help clarify something? We were testing a few of the jax.scipy.linalg functions with the LAPACK wrappers mentioned above and noticed that in certain cases we get incorrect results when calling JAX linear algebra functions from within our qjit-compiled block. For example, with jax.scipy.linalg.lu:

import jax
import jax.numpy as jnp
import qjit

A = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])

P, L, U = jax.scipy.linalg.lu(A)

@qjit
def qjit_lu(X):
    return jax.scipy.linalg.lu(X)

P_qjit, L_qjit, U_qjit = qjit_lu(A)

assert jnp.allclose(P @ L @ U, A)  # Passes
assert jnp.allclose(P, P_qjit)     # Fails
assert jnp.allclose(L, L_qjit)     # Fails
assert jnp.allclose(U, U_qjit)     # Fails

We noticed, for example, that in this case, the rows of the permutation matrix P come out in a different order:

>>> print(P)
[[0. 1. 0.]
 [0. 0. 1.]
 [1. 0. 0.]]
>>> print(P_qjit)
[[0. 1. 0.]
 [1. 0. 0.]
 [0. 0. 1.]]

To get to this point we've used the jax::Getrf::Kernel function:

template <typename T>
typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;

template <typename T>
void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int b = *(reinterpret_cast<int32_t*>(data[0]));
  int m = *(reinterpret_cast<int32_t*>(data[1]));
  int n = *(reinterpret_cast<int32_t*>(data[2]));
  const T* a_in = reinterpret_cast<T*>(data[3]);

  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  int* ipiv = reinterpret_cast<int*>(out[1]);
  int* info = reinterpret_cast<int*>(out[2]);
  if (a_out != a_in) {
    std::memcpy(a_out, a_in,
                static_cast<int64_t>(b) * static_cast<int64_t>(m) *
                    static_cast<int64_t>(n) * sizeof(T));
  }
  for (int i = 0; i < b; ++i) {
    fn(&m, &n, a_out, &m, ipiv, info);
    a_out += static_cast<int64_t>(m) * static_cast<int64_t>(n);
    ipiv += std::min(m, n);
    ++info;
  }
}

template struct Getrf<float>;
template struct Getrf<double>;
template struct Getrf<std::complex<float>>;
template struct Getrf<std::complex<double>>;

We noticed that a number of FFI kernels were added recently, e.g. jax::LuDecomposition::Kernel. Should we be using these kernels instead? We used the other functions because they don't depend on the XLA FFI libraries.

joeycarter commented 1 month ago

We figured out the issue with the JIT-compiled block giving a different answer: our JIT compiler sends the input to the LAPACK call in C-ordered (row-major) format, but if I've understood correctly, the scipy LAPACK calls expect FORTRAN-ordered (column-major) format.

Sorry for the noise! I am still curious whether we should be using the old kernel functions or their FFI variants, though.
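For anyone who hits the same thing, here is a minimal pure-Python sketch of the effect, using the matrix from the lu example above. The helper names are hypothetical. It only shows that a buffer written in row-major order and read back as column-major is seen as the transpose; it does not call LAPACK.

```python
def flatten_row_major(matrix):
    """Lay out a nested-list matrix in C (row-major) order."""
    return [x for row in matrix for x in row]


def read_column_major(buffer, rows, cols):
    """Interpret a flat buffer as a Fortran (column-major) rows x cols matrix.

    In column-major order, element (i, j) lives at offset i + j * rows.
    """
    return [[buffer[i + j * rows] for j in range(cols)] for i in range(rows)]


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


A = [[1.0, 2.0, 3.0],
     [5.0, 4.0, 2.0],
     [3.0, 2.0, 1.0]]

buffer = flatten_row_major(A)                    # what a row-major caller passes
seen_by_lapack = read_column_major(buffer, 3, 3)  # what a column-major routine reads

assert seen_by_lapack == transpose(A)
```

So a column-major routine handed this buffer factorizes the transpose of A, not A itself, which would explain a different pivot order.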

dfm commented 1 month ago

Glad you got that figured out! Yeah, our plan is to migrate all the custom calls to the FFI in the near future (see https://github.com/google/jax/issues/23056), so in the long run, that's what you'll need to target. Unfortunately we're currently in that awkward compatibility period where both exist in parallel, and the FFI kernels don't all exist yet!

joeycarter commented 1 month ago

Great, thanks for the clarification!

While we're on the subject, I'm curious how jax handles the row-/column-major issue. Is there a transformation that occurs somewhere before the call to the LAPACK routine that ensures the array is in column-major format? If so could you point me to where in the code that happens?

dfm commented 1 month ago

Sure! The place where this is specified on the JAX side is via the operand_layouts and result_layouts parameters in the custom call lowering. For an n-dimensional input, we pass (n - 2, n - 1, n - 3, n - 4, ..., 0) as the layout to specify column-major for the trailing two (matrix) dimensions, instead of (n - 1, n - 2, ..., 0) for row-major. For example:

https://github.com/google/jax/blob/530ed026b8926cba3cb3d06c855b516fd4c9fb38/jaxlib/gpu_solver.py#L112

I haven't gone spelunking to find out where exactly this is used in XLA, but I'm sure it would be possible to track!
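As a rough sketch of the layout tuples described above, the two helpers below build them for a given number of dimensions. The function names are hypothetical and are not part of jaxlib. The sketch assumes the tuple lists dimensions from fastest-varying to slowest-varying (minor to major), which is my reading of the convention rather than something confirmed in this thread.

```python
def column_major_layout(ndim):
    """Layout tuple (n - 2, n - 1, n - 3, ..., 0) for an ndim-dimensional array.

    Hypothetical helper: the trailing two (matrix) dimensions are stored
    column-major, and any leading batch dimensions stay in the usual order.
    """
    if ndim < 2:
        raise ValueError("a matrix layout needs at least 2 dimensions")
    return (ndim - 2, ndim - 1) + tuple(range(ndim - 3, -1, -1))


def row_major_layout(ndim):
    """Layout tuple (n - 1, n - 2, ..., 0), the default row-major order."""
    return tuple(range(ndim - 1, -1, -1))


for ndim in (2, 3, 4):
    print(ndim, column_major_layout(ndim), row_major_layout(ndim))
```

For a single matrix (ndim = 2) this gives (0, 1) rather than (1, 0). For a batch of matrices (ndim = 3) it gives (1, 2, 0) rather than (2, 1, 0).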