google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Very slow JIT compilation due to constant-folding of very large BCOO sparse matrix nonzero entry array #14655

Open aterenin opened 1 year ago

aterenin commented 1 year ago

Description

I've got a use case where I'd like to store the nonzero entries of a very large sparse matrix, and then access them later during a machine learning training loop. Unfortunately, using JIT compilation results in constant-folding of this array, making it extremely slow on large problems. Here's an MWE that runs on my laptop and captures the typical behavior:

import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse
from jax.experimental.sparse import BCOO

n = 10000000

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)
    def product(other):
        matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

operator = build_sparse_linear_operator()

def fn(x):
    return operator(jnp.ones(n) / x).sum()

fn(1.0) # executes in 0.1s
jax.jit(fn)(1.0) # executes in almost one minute

Calling the function without JIT executes in about a tenth of a second, but calling it with JIT takes almost a minute. On larger problems in the codebase which prompted this MWE, I have had it crash due to running out of memory after about an hour. This produces warnings similar to the following:

  Constant folding an instruction is taking > 8s:

    slice.22 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

  This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

The problem seems to be that the stored array, nonzeroes, has shape (n, 2), which in this case is very large, yet the JIT compiler tries to constant-fold it anyway. This seems like a bug, unless there are good reasons why arrays with millions of elements should be constant-folded, in which case it would be very helpful to have some way of telling the compiler not to do so.
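As a small-scale sketch of what is happening (with a tiny stand-in array rather than the full (n, 2) index array), the closed-over array shows up as a literal constant in the lowered program, which is what makes it eligible for constant folding:

```python
import jax
import jax.numpy as jnp

# Tiny stand-in for the large (n, 2) index array closed over by the function:
const = jnp.arange(20).reshape(10, 2)

def f(x):
    # `const` is captured by closure rather than passed as an argument,
    # so it is baked into the compiled program as a literal.
    return (const / x).sum()

# The lowered program text contains the array as an embedded constant:
hlo_text = jax.jit(f).lower(1.0).as_text()
```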

What jax/jaxlib version are you using?

v0.4.4

Which accelerator(s) are you using?

N/A

Additional system info

N/A

NVIDIA GPU info

N/A

jakevdp commented 1 year ago

It seems like the compiler isn't making a great choice here with respect to the run-time/compile-time tradeoffs involved in constant folding. If you change your code slightly though, the problematic array will be computed at runtime:

def build_sparse_linear_operator():
    def product(other):
        nonzeroes = sparse.eye(n).indices  # shape (n, 2)
        matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

aterenin commented 1 year ago

It seems like the compiler isn't making a great choice here with respect to the run-time/compile-time tradeoffs involved in constant folding. If you change your code slightly though, the problematic array will be computed at runtime:

def build_sparse_linear_operator():
    def product(other):
        nonzeroes = sparse.eye(n).indices  # shape (n, 2)
        matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

Thanks! Unfortunately, while this would avoid the problem here, I can't fix things upstream in that way: the array nonzeroes in my codebase is computed by what is effectively a black-box, CPU-only algorithm outside of JAX. Perhaps a better way to write the MWE would have been to use scipy.sparse.eye or similar instead.
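For instance, with scipy.sparse standing in for the black-box algorithm, the indices might be produced outside JAX along these lines:

```python
import numpy as np
import scipy.sparse

n = 1000  # small stand-in for the real problem size

# A CPU-only step outside of JAX, emulated here with scipy.sparse:
sp = scipy.sparse.eye(n).tocoo()
nonzeroes = np.stack([sp.row, sp.col], axis=1)  # shape (n, 2)
```

JAX then only ever sees nonzeroes as a precomputed array, so the restructuring above isn't available.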

Do you know if there are any workarounds I can implement to prevent the compiler from constant-folding the array?

jakevdp commented 1 year ago

I don't know... it's really an XLA bug, and I'm not sure of a way to change what XLA does here. Maybe you could rewrite your code so that nonzeroes is explicitly passed as an argument to the jit-compiled outer function? I know that's probably not the answer you're looking for, but I think it would work.
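A minimal sketch of that suggestion, on a small n for illustration: the index array becomes a traced jit input rather than an embedded constant.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

n = 100  # small size for illustration

def fn(x, nonzeroes):
    # `nonzeroes` is now a traced argument, so it cannot be constant-folded.
    matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n),
                  indices_sorted=True, unique_indices=True)
    return (matrix @ (jnp.ones(n) / x)).sum()

fn_jit = jax.jit(fn)

# Identity indices, as in the MWE:
nonzeroes = jnp.stack([jnp.arange(n), jnp.arange(n)], axis=1)
result = fn_jit(2.0, nonzeroes)  # identity @ (ones / 2) sums to n / 2
```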

aterenin commented 1 year ago

Thanks! Unfortunately, passing around nonzeroes won't work: it is computed internally by the package and should not be exposed to the user, and the functions it would need to be passed into are called directly by the user.

Should I file a bug report upstream? Would the TensorFlow repository be the right place?

jakevdp commented 1 year ago

It looks like one possible workaround for now is to use an optimization barrier:

from jax._src.ad_checkpoint import _optimization_barrier

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)
    def product(other):
        nz = _optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n,n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

This is still somewhat experimental, so unfortunately there is no public API for this.
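For completeness, an end-to-end sketch of this workaround on a small n. The import path above is private as of jax v0.4.4 and may move in later releases (newer versions expose a public jax.lax.optimization_barrier), hence the fallback import below:

```python
import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse
from jax.experimental.sparse import BCOO

try:
    # Public in newer JAX releases.
    from jax.lax import optimization_barrier
except ImportError:
    # Private path as of jax v0.4.4.
    from jax._src.ad_checkpoint import _optimization_barrier as optimization_barrier

n = 100  # small size for illustration

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)
    def product(other):
        # The barrier keeps XLA from constant-folding the index array.
        nz = optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n, n),
                      indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

operator = build_sparse_linear_operator()

def fn(x):
    return operator(jnp.ones(n) / x).sum()

result = jax.jit(fn)(2.0)  # identity @ (ones / 2) sums to n / 2
```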

jakevdp commented 1 year ago

That said, it's probably worth filing an XLA bug for this. It should be something that the compiler handles automatically: https://github.com/openxla/xla

aterenin commented 1 year ago

Thanks, that works! Two comments:

  1. The function _optimization_barrier must be called inside product, and not outside of it, so for instance nonzeroes = _optimization_barrier(sparse.eye(n).indices) will not work.
  2. This XLA bug might be specific to integer arrays: in my upstream codebase, I have other arrays which are also precomputed, but wrapping only the BCOO nonzero index arrays in _optimization_barrier is sufficient to keep JIT from freezing.

Very much appreciate your help with this!
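A runnable sketch of point 1, on a small array: applied eagerly, outside any trace, the barrier simply returns a concrete array, which is then closed over as a compile-time constant like any other; only a call made during tracing emits an OptimizationBarrier op into the compiled program.

```python
import jax
import jax.numpy as jnp

try:
    from jax.lax import optimization_barrier  # public in newer JAX releases
except ImportError:
    from jax._src.ad_checkpoint import _optimization_barrier as optimization_barrier

nonzeroes = jnp.arange(10)

# No effect: runs eagerly and returns a plain concrete array, which jit
# will still treat as a foldable constant when it is closed over later.
eager = optimization_barrier(nonzeroes)

def fn(x):
    # Effective: staged into the traced program as an OptimizationBarrier op,
    # which XLA will not constant-fold through.
    nz = optimization_barrier(nonzeroes)
    return (nz * x).sum()

result = jax.jit(fn)(2.0)  # 2 * (0 + 1 + ... + 9) = 90
```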

mjsML commented 1 year ago

@aterenin, even though it seems like an XLA bug, it would be helpful to mention which hardware you are using (compilers have different backends, so code paths are different).

aterenin commented 1 year ago

@aterenin, even though it seems like an XLA bug, it would be helpful to mention which hardware you are using (compilers have different backends, so code paths are different).

Sure! I have reproduced this issue on both an NVIDIA GPU and an Apple M1.

watkinrt commented 6 months ago

I know this thread is about a year old, but I thought I would note that I've run into similar issues when JIT-compiling large sparse arrays (in my case, for moderately sized finite element simulations of ~2M elements). Beyond slow constant folding, I've also hit a limit where XLA segfaults if the array is too large (generally somewhere above 400,000,000 nonzero elements). In that case, _optimization_barrier has no effect for me, and the only solution is to provide the sparse array indices as an input to my function. I've run into this issue on both Linux and Windows. Due to the size of the arrays, I've only been able to test on CPU, as my GPUs don't have enough memory for problems this large.