jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

solve and triangular_solve fail to return Inf for batches of singular matrices on CPU #3589

Closed dpfau closed 3 months ago

dpfau commented 4 years ago

This is an issue with the CPU backend when running triangular_solve. If I pass a single rank-deficient matrix to triangular_solve (i.e. an upper triangular matrix with zeros on the diagonal), it solves correctly by back-substitution until it reaches the first zero, and returns Inf or NaN for all columns after that. If I run a batch of matrices through triangular_solve, however, it returns zeros in the later columns (and a nonsense result in the last column) instead of Inf/NaN. The error propagates to jnp.linalg.solve, as the following example shows:

from jax import lax_linalg
import numpy as np
import jax.numpy as jnp

n = 12
k = [0, 3, 5, 0, 1]  # corank of each batch element
u = np.triu(np.random.rand(len(k), n, n) + 1)
for i in range(len(k)):
  if k[i] > 0:
    u[i, -k[i]:] = 0

foo = lax_linalg.triangular_solve(
    u, np.ones((len(k), 1, n)),
    left_side=False, transpose_a=False, lower=False)
bar = jnp.linalg.solve(u.transpose((0, 2, 1)), np.ones((len(k), n, 1))).transpose((0, 2, 1))

for i in range(len(k)):
  print(lax_linalg.triangular_solve(
    u[i], np.ones((1, n)),
    left_side=False, transpose_a=False, lower=False))
  print('')
  print(foo[i])
  print('')
  print(bar[i])
  print('\n\n')

The first result (triangular_solve on a single matrix) is correct, while the two results that follow it (triangular_solve and solve on a batch of matrices) are incorrect. The bug is not present on GPU.

hawkinsp commented 4 years ago

For the batched case, the implementation switches to a completely different algorithm (actually, the same implementation used on TPU): https://github.com/google/jax/blob/7b57dc8c8043163a5e649ba66143ccef880d7d58/jax/lax_linalg.py#L436 For the batch-size-1 case, we call LAPACK TRSM, which behaves as you describe.

The batched case calls into XLA, which uses an algorithm inspired by MAGMA that inverts diagonal blocks: https://github.com/tensorflow/tensorflow/blob/bd006c354f11f9045d344f3e48b47be9f8368dac/tensorflow/compiler/xla/service/triangular_solve_expander.cc#L439
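To illustrate the general idea of solving by inverting diagonal blocks, here is a minimal NumPy sketch. It is not XLA's implementation: the function name `blocked_upper_solve` and the `block` parameter are made up for this example, and it handles only an unbatched, left-side, upper-triangular system. It shows that each block row is resolved with an explicit block inverse and matrix multiplications, rather than element-by-element back-substitution.

```python
import numpy as np

def blocked_upper_solve(u, b, block=2):
    """Solve u @ x = b for upper-triangular u by inverting diagonal blocks.

    Hypothetical helper for illustration only, not the XLA algorithm.
    """
    n = u.shape[0]
    x = np.zeros_like(b)
    # Walk the diagonal blocks from the bottom-right corner upwards.
    for start in reversed(range(0, n, block)):
        stop = min(start + block, n)
        # Subtract the contribution of the unknowns that are already solved.
        rhs = b[start:stop] - u[start:stop, stop:] @ x[stop:]
        # Apply the explicit inverse of the diagonal block via a matmul.
        x[start:stop] = np.linalg.inv(u[start:stop, start:stop]) @ rhs
    return x

rng = np.random.default_rng(0)
u = np.triu(rng.random((5, 5)) + 1)  # nonsingular upper-triangular matrix
b = np.ones((5, 1))
x = blocked_upper_solve(u, b)
residual_ok = np.allclose(u @ x, b)
print(residual_ok)
```

Because the result of each block goes through a matrix product, the behavior on a singular block depends on how the block inverse is computed, which is one way the output can differ from plain back-substitution.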

hawkinsp commented 4 years ago

I'm wondering if you care about the values returned if the matrix is singular, or whether you would be happy to get, say, a matrix full of NaNs out for that batch element. Note that, say, scipy.linalg.solve_triangular would raise a singular matrix exception in the corresponding situation.

dpfau commented 4 years ago

I am working on a fix so that the gradient of the LU decomposition returns the correct value even if the input matrix is singular, and I use the columns of Infs/NaNs to identify what the rank of the matrix is, so in this case it actually is important that it returns Inf/NaN instead of zeros or an error.
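As a rough sketch of the rank-detection trick described above, assuming the solver output has shape (batch, 1, n) as in the example at the top of the thread: count how many leading columns are finite before the first Inf/NaN appears. The helper name `leading_finite_columns` and the hand-written input array are hypothetical and only serve the illustration.

```python
import numpy as np

def leading_finite_columns(x):
    """Number of columns before the first Inf/NaN column, per batch element."""
    # A column counts as finite only if every row entry in it is finite.
    finite = np.isfinite(x).all(axis=-2)  # shape (batch, n)
    n = finite.shape[-1]
    # argmax on the negated mask gives the index of the first non-finite column.
    first_bad = (~finite).argmax(axis=-1)
    return np.where(finite.all(axis=-1), n, first_bad)

# Hand-written stand-in for a batched solver output of shape (batch, 1, n).
x = np.array([[[1.0, 2.0, 3.0, 4.0]],
              [[1.0, 2.0, np.inf, np.nan]],
              [[0.5, np.nan, np.nan, np.nan]]])
counts = leading_finite_columns(x)
print(counts)  # [4 2 1]
```

This only works if the solver actually emits Inf/NaN from the first singular column onward, which is exactly the behavior the batched CPU path does not provide.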

dpfau commented 4 years ago

This same trick is currently used by the gradient of np.linalg.det when dealing with singular matrices, which leads me to believe that it may also fail in this case.

hawkinsp commented 4 years ago

Wouldn't it suffice to look for the first 0 on the diagonal of the input matrix?
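For concreteness, a minimal NumPy sketch of that diagonal check, using the same batch construction as the example at the top of the thread. The function name `first_zero_on_diagonal` and the `tol` parameter are assumptions made for this example; with `tol=0.0` only exact zeros are detected.

```python
import numpy as np

def first_zero_on_diagonal(u, tol=0.0):
    """Index of the first diagonal entry with |value| <= tol, or n if none."""
    diag = np.diagonal(u, axis1=-2, axis2=-1)  # shape (batch, n)
    is_zero = np.abs(diag) <= tol
    n = diag.shape[-1]
    # argmax returns the first True; fall back to n when there is no zero.
    return np.where(is_zero.any(axis=-1), is_zero.argmax(axis=-1), n)

n = 12
k = [0, 3, 5, 0, 1]  # corank of each batch element
rng = np.random.default_rng(0)
u = np.triu(rng.random((len(k), n, n)) + 1)
for i, corank in enumerate(k):
    if corank > 0:
        u[i, -corank:] = 0  # zero out the last `corank` rows

first_zero = first_zero_on_diagonal(u)
print(first_zero)  # [12  9  7 12 11]
```

For these inputs the position of the first zero equals n minus the corank, so it recovers the same information as scanning the solver output for Inf/NaN.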

(I'm also just trying to understand what API contract you expect, because the usual contract says "this is an illegal input". It's possible we can make the XLA algorithm mimic the behavior of the usual TRSM algorithm in this case, but it's not clear to me the behavior in the singular case is actually well defined without also fixing the choice of algorithm.)

dpfau commented 4 years ago

I'm worried that would miss Inf/NaN due to underflow for values that aren't identically zero. Different backends might have different tolerances and I want to make sure I catch everything. It's possible I'm being too cautious though.

I definitely don't want it to raise an error. I want this to work on singular matrices - triangular_solve still gives useful results for all columns up to the rank of the matrix.

hawkinsp commented 3 months ago

Currently, you'll get an output containing NaNs if you pass a singular matrix. The NaNs will not necessarily start at the relevant column: that depends on the algorithm choice, and because some of the algorithms involve matrix multiplication, any NaNs that are present get smeared across the output.
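The smearing effect can be seen in a small NumPy sketch, independent of any particular solver: a single NaN in one operand of a matrix product contaminates every output entry whose sum includes it. The matrices here are arbitrary values chosen for the illustration.

```python
import numpy as np

a = np.ones((3, 3))
b = np.arange(9.0).reshape(3, 3)
b[1, 2] = np.nan  # a single NaN in the right-hand operand

c = a @ b
# Every entry of column 2 sums over b[:, 2], so the whole column becomes NaN.
nan_count_per_column = np.isnan(c).sum(axis=0)
print(nan_count_per_column)  # [0 0 3]
```

Once intermediate results pass through a product like this, the position of the first NaN in the output no longer identifies the column where the singularity was encountered.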