jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.lax.linalg.lu returns LU factorisation for singular matrix, jax.scipy.linalg.lu_solve silently returns spurious result in 32bit #23626

Open johannahaffner opened 2 weeks ago

johannahaffner commented 2 weeks ago

Description

Singular matrices with linearly dependent columns can have infinitely many LU factorisations. However, the lax implementation of LU factorisation returns a factorisation regardless. Passing it to jax.scipy.linalg.lu_solve then fails to produce a valid result, yet the result is still finite unless the calculation is performed in double precision, and no error is raised.
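For illustration, a minimal sketch of what the returned factorisation looks like for the 2x2 singular matrix used below. It assumes the documented packing of `lax_linalg.lu` (unit lower-triangular L below the diagonal of `lu`, U on and above it) and that `permutation` satisfies `A[permutation] == L @ U`. The factorisation itself reconstructs A; the singularity only shows up as a zero (or, after rounding, tiny) pivot on U's diagonal, which is what the triangular solve inside lu_solve later divides by.

```python
import jax.numpy as jnp
import jax.lax.linalg as lax_linalg

A = jnp.array([[1., 2.],
               [2., 4.]])  # singular: second column is twice the first

lu, pivots, permutation = lax_linalg.lu(A)

# Unpack the combined lu matrix into L (unit diagonal implied) and U.
n = A.shape[0]
L = jnp.tril(lu, k=-1) + jnp.eye(n, dtype=lu.dtype)
U = jnp.triu(lu)

# The factorisation is valid: the permuted rows of A equal L @ U ...
reconstructs = jnp.allclose(A[permutation], L @ U)

# ... but U has a (near-)zero pivot on its diagonal.
pivot_magnitudes = jnp.abs(jnp.diagonal(U))
print(reconstructs, pivot_magnitudes)
```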

I'm accessing this through a downstream library (lineax). If this cannot (or perhaps should not) be fixed in JAX, I could work around it by checking whether $Ax = b$ holds.
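A minimal sketch of that workaround, assuming a residual test is acceptable: solve with lu_factor/lu_solve, then return a boolean flag instead of raising, so the function stays jit-compatible. `lu_solve_checked` is a hypothetical helper name, and the tolerances `rtol`/`atol` are arbitrary choices rather than values from JAX or lineax.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


@jax.jit
def lu_solve_checked(A, b, rtol=1e-4, atol=1e-6):
    """Hypothetical helper: LU solve plus a residual-based validity flag.

    ok is True only if x is finite and ||A x - b|| <= atol + rtol * ||b||.
    """
    x = jsl.lu_solve(jsl.lu_factor(A), b)
    residual = jnp.linalg.norm(A @ x - b)
    # Comparisons with NaN evaluate to False, so a NaN residual also fails.
    ok = jnp.isfinite(x).all() & (residual <= atol + rtol * jnp.linalg.norm(b))
    return x, ok


A = jnp.array([[1., 2.],
               [2., 4.]])
x, ok = lu_solve_checked(A, jnp.array([3., 6.]))
```

Because `ok` is an array rather than an exception, it can be combined with `jnp.where` inside jitted code or inspected in Python once the result is concrete.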

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.lax.linalg as lax_linalg

DOUBLE = False  # set to True to run in float64
if DOUBLE:
    # must run before any JAX arrays are created
    jax.config.update("jax_enable_x64", True)

A = jnp.array([[1., 2.],   
               [2., 4.]])  
b = jnp.array([3., 6.])

a_lu, a_pivots, a_permutation = lax_linalg.lu(A)  # returns (lu, pivots, permutation)
a_x = jsp.linalg.lu_solve((a_lu, a_pivots), b)

C = jnp.array([[ 8.,  2., -1., -1., -1.], 
               [ 2.,  2.,  1., -1.,  0.], 
               [ 1., -1.,  0.,  0.,  0.], 
               [ 1.,  1.,  0.,  0.,  0.], 
               [ 1.,  0.,  0.,  0.,  0.]])
d = jnp.array([2., 3., 0., 4., 3.])

c_lu, c_pivots, c_permutation = lax_linalg.lu(C)  
c_x = jsp.linalg.lu_solve((c_lu, c_pivots), d)

print(a_x, c_x)

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.31
jaxlib: 0.4.31
numpy: 2.1.0
python: 3.12.4 (v3.12.4:8e8a4baf65, Jun 6 2024, 17:33:18) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='REDACTED', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:16:46 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T8112', machine='arm64')

jakevdp commented 2 weeks ago

Hi - thanks for the report! This is the expected output in jax.scipy.linalg. One of the restrictions on JAX APIs that enables compilation and execution at scale is that Python-level control flow (such as raising LinAlgError) cannot depend on the values within arrays (this is similar to what is discussed in Sharp Bits: out-of-bounds indexing).
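To make the restriction concrete, here is a small sketch: a Python `if` on an array value works eagerly but fails when traced under `jax.jit`. One documented way to get value-dependent errors out of jitted code is `jax.experimental.checkify`, which records a failed check instead of raising during tracing. The function names `raise_if_not_finite` and `check_finite` are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def raise_if_not_finite(x):
    # Python-level control flow that depends on array *values*.
    if not jnp.isfinite(x).all():
        raise ValueError("non-finite solution")
    return x


x_bad = jnp.array([1.0, jnp.nan])

# Under jit the condition is an abstract tracer, so converting it to a
# Python bool raises a ConcretizationTypeError subclass at trace time.
try:
    jax.jit(raise_if_not_finite)(x_bad)
    jit_error = None
except jax.errors.ConcretizationTypeError as e:
    jit_error = type(e).__name__


def check_finite(x):
    # checkify.check records a value-dependent error instead of raising.
    checkify.check(jnp.isfinite(x).all(), "non-finite solution")
    return x


checked = jax.jit(checkify.checkify(check_finite))
err, _ = checked(x_bad)
message = err.get()  # None if every check passed, otherwise an error string
# err.throw() would re-raise the recorded error as a Python exception.
```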

What that means is that in places where scipy might raise a LinAlgError on singular input, JAX will instead finish the execution and return matrices containing NaN values. In your own code, depending on the context, we'd suggest either doing the same (let your code execute but return invalid values for invalid inputs) or, if you're outside the context of a transformation like JIT, explicitly checking for NaN values to see whether the previous computation was successful.
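Both suggestions can be sketched together, assuming a simple pivot-ratio heuristic is good enough to decide that a matrix "looks singular": inside jit, `jnp.where` replaces the solution with NaN without any value-dependent Python control flow; outside jit, the concrete result is checked for NaN. `solve_nan_on_singular` and the threshold `pivot_rtol` are hypothetical, and a pivot ratio is a cheap indicator, not a reliable rank test.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


@jax.jit
def solve_nan_on_singular(A, b, pivot_rtol=1e-6):
    """Hypothetical helper: return all-NaN x when U has a (near-)zero pivot."""
    lu, piv = jsl.lu_factor(A)
    x = jsl.lu_solve((lu, piv), b)
    pivots = jnp.abs(jnp.diagonal(lu))  # diagonal of U
    looks_singular = jnp.min(pivots) <= pivot_rtol * jnp.max(pivots)
    # jnp.where selects by value without Python branching, so this traces fine.
    return jnp.where(looks_singular, jnp.nan, x)


A = jnp.array([[1., 2.],
               [2., 4.]])
x = solve_nan_on_singular(A, jnp.array([3., 6.]))

# Outside of jit the result is concrete and can be checked in Python.
if jnp.isnan(x).any():
    print("solve failed: matrix appears singular")
```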

What do you think?

johannahaffner commented 2 weeks ago

Hi Jake!

Thanks for the quick reply, and the thorough explanation. This does make a lot of sense.

However, the issue I ran into is that lu_solve does not always return NaN values. It can return something that looks like a solution (all real values, for instance), but does not solve the problem.

I checked this for some other singular matrices too. It does indeed frequently return NaN values, but every now and then it seems to manage to find a "solution" through some lucky combination of floating-point error accumulation in 32-bit.

If catching this will always require evaluating $Ax=b$, then I suggest we simply add a warning to the documentation of lu, lu_solve, and lu_factor.

johannahaffner commented 1 week ago

~Taking a look at the implementations, I noticed that the JAX implementation does not use partial pivoting, while the scipy version does. I added a half sentence to the docstrings of the relevant functions to specify this.~

I also checked other linear algebra operations for the singular 5x5 matrix in the example above. In 32-bit I get a finite determinant and inverse for this matrix too. I think both are numerical-stability issues, and the same would probably happen with other singular matrices that would become nonsingular if one or more of their zero entries were nonzero.
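In floating point the determinant is a weak singularity test; singular-value-based quantities are usually more informative. A sketch for the 5x5 matrix from the first example, assuming `jnp.linalg.cond` and `jnp.linalg.matrix_rank` (with its default, dtype-dependent tolerance) are acceptable diagnostics. Any threshold you then apply to `cond_C` is your own choice, not something JAX prescribes.

```python
import jax.numpy as jnp

# The singular 5x5 matrix from the first example: columns 3-5 only have
# nonzero entries in the first two rows, so they are linearly dependent.
C = jnp.array([[ 8.,  2., -1., -1., -1.],
               [ 2.,  2.,  1., -1.,  0.],
               [ 1., -1.,  0.,  0.,  0.],
               [ 1.,  1.,  0.,  0.,  0.],
               [ 1.,  0.,  0.,  0.,  0.]])

det_C = jnp.linalg.det(C)            # may come out nonzero in float32
cond_C = jnp.linalg.cond(C)          # largest / smallest singular value
rank_C = jnp.linalg.matrix_rank(C)   # SVD-based, with a default tolerance
print(det_C, cond_C, rank_C)
```

A very large (or infinite) condition number flags the matrix as numerically singular even when `det` and `inv` return finite values.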

Therefore, checking for NaN/inf values does not always seem to be enough to catch all invalid results. I propose that we mention this briefly in the relevant Sharp Bits section on divergences from NumPy. (E.g. ~no partial pivoting / numerical stability can differ~, no value-dependent error handling, in case of strange values check for divergences between 32-bit and 64-bit, and look at other tests.)
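The "check for divergences between 32-bit and 64-bit" suggestion could look roughly like this. It is a sketch that assumes x64 mode may be enabled for the whole process. `diverges_between_precisions` and its tolerances are hypothetical, and it is only a heuristic: the transposed-matrix experiment in this thread reportedly behaves the same in double precision.

```python
import jax

# Must run before any JAX arrays are created; otherwise float64 is unavailable.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.scipy.linalg as jsl


def lu_solve_as(dtype, A, b):
    A = jnp.asarray(A, dtype=dtype)
    b = jnp.asarray(b, dtype=dtype)
    return jsl.lu_solve(jsl.lu_factor(A), b)


def diverges_between_precisions(A, b, rtol=1e-3, atol=1e-5):
    """Hypothetical heuristic: flag solves whose float32/float64 results disagree.

    jnp.allclose treats NaN as unequal by default (equal_nan=False), so a NaN
    in either result also counts as a divergence.
    """
    x32 = lu_solve_as(jnp.float32, A, b)
    x64 = lu_solve_as(jnp.float64, A, b)
    return not bool(jnp.allclose(x32, x64, rtol=rtol, atol=atol))


print(diverges_between_precisions([[1., 2.], [2., 4.]], [3., 6.]))
```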

What do you think?

hawkinsp commented 1 week ago

I don't follow the first part of this. The JAX implementation does use partial pivoting.

johannahaffner commented 1 week ago

You're right, I mixed something up. Sorry for the confusion!

johannahaffner commented 1 week ago

I continued poking at this, and I still think there is something there (even if it is not yesterday's fever dream).

If I transpose a singular matrix, I quite often get a finite result for a linear solve using LU: in 78% of the cases below. In these cases, checking for NaN/inf values alone is not enough; $Ax=b$ has to be evaluated to check whether the solution is valid.

Here is code that reproduces this:

import jax.numpy as jnp
import jax.random as jr

import jax.scipy.linalg as jsl

template = jnp.array([[1, 1, 1, 1, 1], 
                      [0, 1, 1, 1, 1], 
                      [0, 0, 0, 1, 1],  # zero at (3, 3): 3rd column is linearly dependent on 1st and 2nd
                      [0, 0, 0, 1, 1], 
                      [0, 0, 0, 0, 1]])

def check_lu_solve(matrix, known_b):
    lu_and_piv = jsl.lu_factor(matrix)
    solved_for_x = jsl.lu_solve(lu_and_piv, known_b)
    return jnp.logical_not(jnp.isfinite(solved_for_x).all())  # Expected: not finite

key = jr.key(1234)
original_matrix_results = []
transposed_matrix_results = []
for _ in range(100):
    key, subkey = jr.split(key)

    # Generate a random singular matrix and known b
    random_matrix = jnp.where(template, jr.normal(subkey, shape=(5, 5)), 0)
    random_x = jr.normal(subkey, shape=(5,))  # reuses the subkey from the matrix draw above
    known_b = jnp.dot(random_matrix, random_x)

    # Try original, upper triangular matrix
    original_matrix_results.append(check_lu_solve(random_matrix, known_b))

    # Now transpose the matrix
    random_matrix = jnp.transpose(random_matrix)
    known_b = jnp.dot(random_matrix, random_x)

    # Try solve with transposed matrix
    transposed_matrix_results.append(check_lu_solve(random_matrix, known_b))

original_matrix_results = jnp.array(original_matrix_results)
transposed_matrix_results = jnp.array(transposed_matrix_results)

print("Upper triangular matrix, expected result:", jnp.sum(original_matrix_results), "out of 100")
print("Transposed matrix, expected result:", jnp.sum(transposed_matrix_results), "out of 100")

None of the solutions for $x$ are correct. SciPy returns the same values for $x$, but raises a warning for singular matrices. Switching to double precision does not make a difference for the above code.
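A vectorized variant of the transposed-matrix experiment, sketched under a few assumptions: it uses `jax.vmap` instead of the Python loop, draws separate keys for the matrix and for x (so its random draws, and hence its counts, will differ from the loop's), and uses a relative-residual test with an arbitrary tolerance `rtol=1e-3` to count solves that are finite but do not satisfy $Ax=b$. `make_transposed_problem` and `finite_but_wrong` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.linalg as jsl

# Same sparsity pattern as the loop above; the zero at (3, 3) makes every
# generated upper-triangular matrix (and its transpose) singular.
template = jnp.array([[1, 1, 1, 1, 1],
                      [0, 1, 1, 1, 1],
                      [0, 0, 0, 1, 1],
                      [0, 0, 0, 1, 1],
                      [0, 0, 0, 0, 1]], dtype=bool)


def make_transposed_problem(key):
    # Independent keys for the matrix entries and for x.
    k_matrix, k_x = jr.split(key)
    matrix = jnp.where(template, jr.normal(k_matrix, (5, 5)), 0.0).T
    x_true = jr.normal(k_x, (5,))
    return matrix, matrix @ x_true


def finite_but_wrong(matrix, b, rtol=1e-3):
    """True if lu_solve returns a finite x whose relative residual exceeds rtol."""
    x = jsl.lu_solve(jsl.lu_factor(matrix), b)
    rel_residual = jnp.linalg.norm(matrix @ x - b) / jnp.linalg.norm(b)
    return jnp.isfinite(x).all() & ~(rel_residual <= rtol)


keys = jr.split(jr.key(1234), 100)
matrices, bs = jax.vmap(make_transposed_problem)(keys)
flags = jax.jit(jax.vmap(finite_but_wrong))(matrices, bs)
print("finite but wrong:", int(flags.sum()), "out of 100")
```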