joeycarter opened this issue 3 months ago
Thanks for this request! This is something we'd love to support. We don't have a specific timeline, but I wanted to just confirm here that this feature request is acknowledged.
Adding one clarification: jaxlib doesn't actually come with its own LAPACK library; it links to the one from scipy. The code used to populate our API with those symbols is here:
It's probably worth noting that as part of this conversation!
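For anyone curious, here is a purely illustrative way to see that from Python. This is only a sketch of the idea (not necessarily the mechanism used in the code referenced above): scipy.linalg.cython_lapack exposes its LAPACK C function pointers as PyCapsules, so the routines that ship with scipy can be reached directly.

import ctypes
from scipy.linalg import cython_lapack

# Capsule wrapping scipy's dgetrf function pointer.
capsule = cython_lapack.__pyx_capi__["dgetrf"]

get_name = ctypes.pythonapi.PyCapsule_GetName
get_name.restype = ctypes.c_char_p
get_name.argtypes = [ctypes.py_object]

get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
get_pointer.restype = ctypes.c_void_p
get_pointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

# Raw address of the dgetrf routine bundled with scipy's LAPACK.
dgetrf_addr = get_pointer(capsule, get_name(capsule))
print(hex(dgetrf_addr))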
Hi @dfm,
Could you help clarify something? We were testing a few of the jax.scipy.linalg functions with the LAPACK wrappers mentioned above and noticed that in certain cases we get incorrect results when calling a JAX linear algebra function from within our qjit-compiled block. For example, with jax.scipy.linalg.lu:
import jax
import jax.numpy as jnp
import jax.scipy.linalg

from qjit import qjit  # our JIT compiler

A = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])

P, L, U = jax.scipy.linalg.lu(A)

@qjit
def qjit_lu(X):
    return jax.scipy.linalg.lu(X)

P_qjit, L_qjit, U_qjit = qjit_lu(A)

assert jnp.allclose(P @ L @ U, A)  # Passes
assert jnp.allclose(P, P_qjit)     # Fails
assert jnp.allclose(L, L_qjit)     # Fails
assert jnp.allclose(U, U_qjit)     # Fails
We noticed, for example, that in this case some of the rows of the permutation matrix have been reordered:
>>> print(P)
[[0. 1. 0.]
[0. 0. 1.]
[1. 0. 0.]]
>>> print(P_qjit)
[[0. 1. 0.]
[1. 0. 0.]
[0. 0. 1.]]
To get to this point we've used the jax::Getrf::Kernel function:
template <typename T>
typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;
template <typename T>
void Getrf<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int b = *(reinterpret_cast<int32_t*>(data[0]));
int m = *(reinterpret_cast<int32_t*>(data[1]));
int n = *(reinterpret_cast<int32_t*>(data[2]));
const T* a_in = reinterpret_cast<T*>(data[3]);
void** out = reinterpret_cast<void**>(out_tuple);
T* a_out = reinterpret_cast<T*>(out[0]);
int* ipiv = reinterpret_cast<int*>(out[1]);
int* info = reinterpret_cast<int*>(out[2]);
if (a_out != a_in) {
std::memcpy(a_out, a_in,
static_cast<int64_t>(b) * static_cast<int64_t>(m) *
static_cast<int64_t>(n) * sizeof(T));
}
for (int i = 0; i < b; ++i) {
fn(&m, &n, a_out, &m, ipiv, info);
a_out += static_cast<int64_t>(m) * static_cast<int64_t>(n);
ipiv += std::min(m, n);
++info;
}
}
template struct Getrf<float>;
template struct Getrf<double>;
template struct Getrf<std::complex<float>>;
template struct Getrf<std::complex<double>>;
We noticed that a number of FFI kernels were added recently, e.g. jax::LuDecomposition::Kernel. Should we be using these kernels instead? We used the other functions because they don't depend on the XLA FFI libraries.
We figured out the issue with the JIT-compiled block giving a different answer: our JIT compiler sends the input to the LAPACK call in C-ordered (row-major) format, but, if I've understood correctly, the scipy LAPACK calls expect Fortran-ordered (column-major) format.
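To make the mismatch concrete, here is a small NumPy-only sketch (illustrative; not our actual compiler code) of what a column-major LAPACK routine ends up seeing when it is handed a row-major buffer:

import numpy as np

A = np.array([[1., 2., 3.],
              [5., 4., 2.],
              [3., 2., 1.]])

# The raw row-major (C-ordered) bytes of A...
buf = A.tobytes(order="C")

# ...reinterpreted as a column-major (Fortran-ordered) 3x3 matrix give A^T,
# which is what a routine expecting Fortran order would actually factorize.
A_seen_by_lapack = np.frombuffer(buf).reshape(3, 3, order="F")

print(np.array_equal(A_seen_by_lapack, A.T))  # True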
Sorry for the noise! I am still curious whether we should be using the old kernel function or their FFI variants, though.
Glad you got that figured out! Yeah, our plan is to migrate all the custom calls to the FFI in the near future (see https://github.com/google/jax/issues/23056), so in the long run, that's what you'll need to target. Unfortunately we're currently in that awkward compatibility period where both exist in parallel, and the FFI kernels don't all exist yet!
Great, thanks for the clarification!
While we're on the subject, I'm curious how JAX handles the row-/column-major issue. Is there a transformation that occurs somewhere before the call to the LAPACK routine that ensures the array is in column-major format? If so, could you point me to where in the code that happens?
Sure! The place where this is specified on the JAX side is via the operand_layouts and result_layouts parameters in the custom call lowering. For an n-dimensional input, we pass (n - 2, n - 1, n - 3, n - 4, ..., 0) as the layout to specify column-major (instead of (n - 1, n - 2, ..., 0) for row-major); a small example is sketched below.
I haven't gone spelunking to find out where exactly this is used in XLA, but I'm sure it would be possible to track!
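To illustrate the layouts described above, here is a small sketch of how those minor-to-major tuples could be built. The helper names are made up for this example and are not part of jaxlib:

def column_major_layout(ndim):
    # Minor-to-major order with the last two (matrix) dimensions swapped relative
    # to row-major: the matrix block is column-major while any leading batch
    # dimensions keep their usual order. Matches (n - 2, n - 1, n - 3, ..., 0) above.
    return (ndim - 2, ndim - 1) + tuple(range(ndim - 3, -1, -1))

def row_major_layout(ndim):
    return tuple(range(ndim - 1, -1, -1))

print(column_major_layout(2))  # (0, 1)
print(column_major_layout(4))  # (2, 3, 1, 0)
print(row_major_layout(2))     # (1, 0)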
Hi there,
This request is a follow-up to the discussion here: https://github.com/google/jax/discussions/18065.
What we're trying to accomplish
Suppose we have our own JIT compiler, called qjit. We would like to be able to call the functions in jax.scipy.linalg from within a qjit-compiled function, as in the sketch below. However, when we do so, we get undefined symbol errors.
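A minimal sketch of the kind of call we mean (qjit here stands in for our own hypothetical compiler, and schur is just one example of a routine that lowers to a LAPACK custom call, in this case lapack_zgees for complex input):

import jax.numpy as jnp
import jax.scipy.linalg

from qjit import qjit  # our JIT compiler

@qjit
def f(A):
    # Any LAPACK-backed routine would do as an example; the complex Schur
    # decomposition lowers to the lapack_zgees custom call.
    return jax.scipy.linalg.schur(A, output="complex")

f(jnp.eye(3, dtype=jnp.complex128))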
Current workaround
What we currently do to get around this is what was suggested in https://github.com/google/jax/discussions/18065: manually compile and link in the required custom JAX LAPACK modules under jaxlib/cpu/ to define the required symbols, such as lapack_zgees. However, this is cumbersome and difficult to maintain (these modules may change in a future JAX release).
What else we've tried
We noticed that the jaxlib package ships with the shared object file jaxlib/cpu/_lapack.so, which contains the symbols for the kernel functions that the custom JAX wrappers use. For instance, using the nm tool, we can find the corresponding kernel function for lapack_zgees, namely ComplexGees<std::complex<double>>::Kernel (and its mangled form).
We tried loading this symbol using the dynamic linking loader as follows (simplified for brevity):
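Roughly, the attempt looks like the following ctypes-based sketch (ctypes calls dlopen/dlsym under the hood); the library path and the mangled symbol name below are placeholders, not the real values:

import ctypes

# Placeholder path; in practice this points at the _lapack.so inside the
# installed jaxlib package.
lib_path = "/path/to/site-packages/jaxlib/cpu/_lapack.so"

lib = ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)  # dlopen

# Placeholder for the mangled name of ComplexGees<std::complex<double>>::Kernel.
mangled = "<mangled kernel symbol name>"

try:
    kernel = getattr(lib, mangled)  # dlsym
except AttributeError:
    # This is where we end up: the symbol is not exported, so dlsym cannot
    # resolve it.
    kernel = None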
However, dlsym fails to find the symbol. We believe this is because the function is not exported, as denoted by the lowercase t in the nm output (exported functions are conventionally denoted by uppercase letters, e.g. T). So, we believe we've hit a dead end with this approach.
Possible solutions
Given that these symbols are already shipped with jaxlib in jaxlib/cpu/_lapack.so, would it be possible to make these functions globally visible in a future JAX release, so that they can be loaded dynamically using dlopen and dlsym? If that is not possible, is there another approach to using these custom LAPACK calls that is more amenable than manually building and linking them in ourselves?