google / TensorNetwork

A library for easy and efficient manipulation of tensor networks.
Apache License 2.0

Jax eigsh_lanczos silently gives NaN when num_krylov_vecs > n #771

Open alewis opened 4 years ago

alewis commented 4 years ago

This happens in particular when num_krylov_vecs is not supplied and the function is applied to an operator that is too small.

mganahl commented 4 years ago

I'll have a look tonight

chaserileyroberts commented 4 years ago

Indexing outside the shape of an array can silently produce NaNs in JAX instead of raising an error; that's likely the issue.
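A minimal sketch of what "silently" means here, assuming a JAX version where `.at[...].get` accepts a `mode` argument. The mode is passed explicitly so the example does not depend on version-specific defaults. Which indexing mode, if any, the backend's Lanczos code actually relies on is not established in this thread.

```python
import jax.numpy as jnp

x = jnp.arange(3.0)  # [0., 1., 2.] (float32 unless jax_enable_x64 is set)

# mode="fill": an out-of-bounds read returns the fill value, which
# defaults to NaN for floating-point arrays. No error is raised.
filled = x.at[5].get(mode="fill")

# mode="clip": the out-of-bounds index is clamped to the last valid
# position, so this silently returns x[2] instead.
clipped = x.at[5].get(mode="clip")

print(filled, clipped)
```

Neither read raises, which is why an out-of-range index inside a Krylov loop can show up only as NaN (or wrong) output.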

alewis commented 4 years ago

I'd recommend we put something like num_krylov_vecs = min(num_krylov_vecs, n-1) somewhere far enough down the execution path that n-1 is available.
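A minimal sketch of that suggestion, assuming the operator dimension n can be read off the initial state's size once it is available. The helper name `clamp_num_krylov_vecs` is hypothetical and not part of TensorNetwork.

```python
import numpy as np

def clamp_num_krylov_vecs(num_krylov_vecs, initial_state):
    """Hypothetical helper: cap the Krylov dimension at n - 1, where n is
    the number of entries of the initial state (the operator dimension).

    Follows the min(num_krylov_vecs, n-1) suggestion literally, so n == 1
    would yield 0.
    """
    n = initial_state.size  # .size exists on both NumPy and JAX arrays
    return min(num_krylov_vecs, n - 1)

print(clamp_num_krylov_vecs(100, np.ones(10)))  # 9
print(clamp_num_krylov_vecs(4, np.ones(10)))    # 4
```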

mganahl commented 4 years ago

Sounds good, though I'd like to investigate why the function doesn't stop once it hits an invariant subspace.
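To make the invariant-subspace point concrete, here is a minimal NumPy sketch of the plain Lanczos recursion (no reorthogonalization). It is not TensorNetwork's implementation, and the function names are hypothetical. When the new residual norm beta is numerically zero, the Krylov space is invariant under the operator, and normalizing `w / beta` would give NaN (if beta is exactly zero) or amplified rounding noise. The sketch stops at that point instead. The threshold `breakdown_tol` is an arbitrary choice.

```python
import numpy as np

def lanczos_with_breakdown(matvec, v0, num_krylov_vecs, breakdown_tol=1e-12):
    """Hypothetical sketch: Lanczos recursion that stops early when it
    reaches an invariant subspace. Returns the diagonal (alphas) and
    off-diagonal (betas) entries of the tridiagonal matrix."""
    v = v0 / np.linalg.norm(v0)
    v_prev = np.zeros_like(v)
    beta = 0.0
    alphas, betas = [], []
    for _ in range(num_krylov_vecs):
        w = matvec(v) - beta * v_prev
        alpha = np.dot(v, w)
        w = w - alpha * v
        alphas.append(alpha)
        beta = np.linalg.norm(w)
        if beta < breakdown_tol:
            # Invariant subspace reached: w / beta would be NaN or noise.
            break
        betas.append(beta)
        v_prev, v = v, w / beta
    return np.array(alphas), np.array(betas)


def ritz_values(alphas, betas):
    """Eigenvalues of the tridiagonal matrix built from alphas and betas."""
    k = len(alphas)
    off = betas[: k - 1]
    T = np.diag(alphas) + np.diag(off, 1) + np.diag(off, -1)
    return np.linalg.eigvalsh(T)
```

For example, with `D = np.diag([1.0, 2.0, 3.0])`, calling `lanczos_with_breakdown(lambda x: D @ x, np.ones(3), 10)` asks for more Krylov vectors than the dimension; the recursion stops once the three-dimensional Krylov space is exhausted, and `ritz_values` on the result approximates 1, 2 and 3.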

On Thu, Aug 6, 2020 at 4:06 PM Adam Lewis notifications@github.com wrote:

I'd recommend we put something like num_krylov_vecs = min(num_krylov_vecs, n-1) somewhere far enough down the execution path that n-1 is available.


mganahl commented 4 years ago

@alewis, can you provide me with a minimal example?

mganahl commented 4 years ago

The following example works fine for me:

import tensornetwork as tn
import numpy as np
import jax
from jax import config
config.update('jax_enable_x64', True)
be = tn.backends.backend_factory.get_backend('jax')
def matvec(x, mat):
    return mat @ x
D = 10
H = np.random.rand(D, D)
H += H.T
init = np.random.rand(D)
be.eigsh_lanczos(A=matvec, args=[H], initial_state=jax.numpy.array(init),
                 numeig=3, reorthogonalize=True, num_krylov_vecs=100,
                 tol=1E-10)

chaserileyroberts commented 4 years ago

Let me try and solve this one. I haven't had to do a bug fix in a while (because everything we write is so great!).

alewis commented 4 years ago

It's possible I misdiagnosed this; I'll try to see what's up tomorrow.

On Fri, Aug 7, 2020, 01:25 Chase Roberts notifications@github.com wrote:

@alewis I tried the code you wrote, but the test is passing. Did I miss anything?

import numpy as np
import jax
from tensornetwork.backends.jax import jax_backend

def test_eigsh_lacnan():
    backend = jax_backend.JaxBackend()

    def matvec(x, mat):
        return mat @ x

    D = 10
    np.random.seed(1)
    H = np.random.rand(D, D).astype(np.float64)
    H += H.T
    init = np.random.rand(D).astype(np.float64)
    eigs, results = backend.eigsh_lanczos(
        A=matvec, args=[H], initial_state=jax.numpy.array(init),
        numeig=3, reorthogonalize=True, num_krylov_vecs=100, tol=1E-10)
    assert not np.isnan(eigs).any()
    assert not np.isnan(results).any()


mganahl commented 4 years ago

Actually, it turns out that eigs (or more accurately _implicitly_restarted_arnoldi) has a bug that causes it to return NaN in certain cases.

alewis commented 4 years ago

I'm still getting it from eigsh_lanczos, though; I'm trying to pinpoint exactly when.

mganahl commented 4 years ago

Can you send me the code?

mganahl commented 4 years ago

Or post it here.

alewis commented 4 years ago

import numpy as np
import jax
import tensornetwork

def test_eigsh_lanczos():
  """
  Compares linalg.krylov.eigsh_lanczos with backend.eigsh_lanczos.
  """
  n = 2
  n_kry = 4
  shape = (n, n)
  dtype = np.float32
  A = jax.numpy.ones(shape, dtype=dtype)
  A = 0.5 * (A + A.T)
  x0 = jax.numpy.ones((n, 1), dtype=dtype)

  def array_matvec(B):
    return A @ B
  be = tensornetwork.backends.backend_factory.get_backend('jax')
  test_result = be.eigsh_lanczos(array_matvec, initial_state=x0,
                                 num_krylov_vecs=n_kry)
  tev, teV = test_result
  assert np.all(np.isfinite(np.ravel(np.array(tev))))

  for t in teV:
    assert np.all(np.isfinite(np.ravel(t)))

alewis commented 4 years ago

Ignore the docstring. Does that code pass for you?

mganahl commented 4 years ago

It doesn't, and I think the reason is that the matrix you are passing is singular.

alewis commented 4 years ago

It's actually new information to me that this algorithm is expected to fail for singular matrices. Why does it only fail in Jax?

alewis commented 4 years ago

I just checked: with n = 8 and n_kry = n - 1, it still gives NaN. So it probably has to do with the matrix and not with the number of Krylov vectors.
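One observation consistent with this result and with the invariant-subspace question raised earlier (not confirmed in the thread): in the test above, the initial state `x0 = ones(n)` is an exact eigenvector of `A = ones((n, n))` for every n, since A x0 = n x0. The very first Lanczos residual is therefore zero up to rounding, independently of n_kry. A quick NumPy check, using 1-D vectors for simplicity and a hypothetical helper name:

```python
import numpy as np

def first_residual_norm(A, x0):
    """Norm of the first Lanczos residual w = A v - (v^T A v) v,
    with v = x0 / |x0|. A value of ~0 means x0 is an eigenvector of A."""
    v = x0 / np.linalg.norm(x0)
    w = A @ v
    alpha = v @ w
    return np.linalg.norm(w - alpha * v)

for n in (2, 8):
    A = np.ones((n, n))
    # All-ones start vector, as in the test: residual is at rounding level.
    print(n, first_residual_norm(A, np.ones(n)))

# A generic random start vector does not hit the invariant subspace at once.
rng = np.random.default_rng(0)
print(first_residual_norm(np.ones((8, 8)), rng.random(8)))
```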

alewis commented 4 years ago

What a weird coincidence that implicit Arnoldi had some apparently unrelated bug.

mganahl commented 4 years ago

lol, I guess we're lucky

mganahl commented 4 years ago

It's actually new information to me that this algorithm is expected to fail for singular matrices. Why does it only fail in Jax?

I'm not sure, actually. Lanczos usually works well only for extremal eigenvalues. If the lowest eigenvalue is degenerate, you get a random vector in the degenerate subspace. This still needs to be investigated; I'm just saying that the reason is related to your choice of matrix. Note that as you increase n, that matrix becomes highly singular, so it's probably not the best choice for testing.
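A quick NumPy check of how singular the test matrix is: the n-by-n all-ones matrix has rank 1, with a single eigenvalue equal to n and eigenvalue 0 repeated n - 1 times, so its lowest eigenvalue is (n - 1)-fold degenerate. The helper name is hypothetical.

```python
import numpy as np

def ones_matrix_spectrum(n):
    """Rank and eigenvalues (ascending) of the n-by-n all-ones matrix."""
    A = np.ones((n, n))
    return np.linalg.matrix_rank(A), np.linalg.eigvalsh(A)

for n in (2, 8):
    rank, evals = ones_matrix_spectrum(n)
    # Expect rank 1, one eigenvalue near n, and n - 1 eigenvalues near 0.
    print(n, rank, np.round(evals, 8))
```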