openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Very slow constant folding of very-large integer arrays, for instance when working with sparse matrices #1652

Open aterenin opened 1 year ago

aterenin commented 1 year ago

This relates to JAX issue #14655; I've copied the relevant details from that thread below.

I've got a use case where I'd like to store the nonzero entries of a very large sparse matrix, and then access them later during a machine learning training loop. Unfortunately, using JIT compilation results in constant-folding of this array, making it extremely slow on large problems. Here's an MWE that runs on my laptop and captures the typical behavior:

import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse
from jax.experimental.sparse import BCOO

n = 10000000

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices # shape (n,2)
    def product(other):
        matrix = BCOO((jnp.ones(n),nonzeroes), shape=(n,n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

operator = build_sparse_linear_operator()

def fn(x):
    return operator(jnp.ones(n) / x).sum()

fn(1.0) # executes in 0.1s
jax.jit(fn)(1.0) # executes in almost one minute

Calling the function without JIT executes in about a tenth of a second, but calling it with JIT takes almost a minute. On larger problems in the codebase which prompted this MWE, I have had it crash due to running out of memory after about an hour. This produces warnings similar to the following:

Constant folding an instruction is taking > 8s:

  slice.22 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

Switching to the following JAX code bypasses the issue:

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices # shape (n,2)
    def product(other):
        # Hide the index array from XLA's simplification passes so it
        # is not constant-folded (public API: jax.lax.optimization_barrier).
        nz = jax.lax.optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n,n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

From this, the problem seems to be that XLA for some reason tries to constant-fold nonzeroes, in spite of its large size, and then runs out of resources while trying to do so. I haven't yet been able to replicate this with float arrays, so I'm not sure whether or not the issue is integer-specific.

tpopp commented 1 year ago

Do you have any ideas on what solution you would like to see?

Right now, XLA will constant-fold operations on data with fewer than 4.5e7 elements, and this array has 1e7 elements. The heuristic could be tweaked to also consider the number of operations being folded, not just the number of elements, but that is a very context-dependent choice. In the end, this really needs to be decided by the user, either through the optimization barrier or through some configuration option that sets an upper bound on how much folding should be done. The upper-bound option is also risky, because it would be an easily overlooked flag that is not updated along with changes to the functions being compiled.

aterenin commented 1 year ago

Thanks!

Re: solution - I'm not sure, it probably merits a bit of thought.

I'd be happy either to add an optimization barrier to my code or to set a configurable maximum array size. A maximum time limit on the computation would also work, though it would lead to different results on different systems. As long as some of these options are available and accessible from JAX, so that I can stop my code from essentially freezing, it would be a big help.

Qazalbash commented 2 months ago

Hey there! I'm experiencing a similar issue. I've been trying to run Bayesian inference with normalizing flows, which involves a massive MLP with over 88 million parameters, and JAX/XLA (I'm not sure which is the culprit) is taking way longer than expected to compile. Can anyone help me out?

2024-05-16 01:52:31.238963: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %dot.19 = f32[10000,2048]{1,0} dot(f32[10000,2048]{1,0} %constant.217, f32[2048,2048]{1,0} %constant.174), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(log_posterior)/jit(main)/vmap(jit(log_posterior))/jit(log_likelihood)/jit(exp_rate)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/media/gradf/Academic/project/gwkokab-runs/o4_powerlawprimarymassratio/o4_powerlaw.py" source_line=117}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2024-05-16 01:52:32.331113: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2.091287591s
Constant folding an instruction is taking > 1s:

  %dot.19 = f32[10000,2048]{1,0} dot(f32[10000,2048]{1,0} %constant.217, f32[2048,2048]{1,0} %constant.174), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(log_posterior)/jit(main)/vmap(jit(log_posterior))/jit(log_likelihood)/jit(exp_rate)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/media/gradf/Academic/project/gwkokab-runs/o4_powerlawprimarymassratio/o4_powerlaw.py" source_line=117}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2024-05-16 01:52:35.025079: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 2s:

  %dot.20 = f32[10000,2048]{1,0} dot(f32[10000,2048]{1,0} %constant.221, f32[2048,2048]{1,0} %constant.177), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(log_posterior)/jit(main)/vmap(jit(log_posterior))/jit(log_likelihood)/jit(exp_rate)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/media/gradf/Academic/project/gwkokab-runs/o4_powerlawprimarymassratio/o4_powerlaw.py" source_line=117}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2024-05-16 01:52:35.083874: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2.05891927s
Constant folding an instruction is taking > 2s:

  %dot.20 = f32[10000,2048]{1,0} dot(f32[10000,2048]{1,0} %constant.221, f32[2048,2048]{1,0} %constant.177), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(log_posterior)/jit(main)/vmap(jit(log_posterior))/jit(log_likelihood)/jit(exp_rate)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/media/gradf/Academic/project/gwkokab-runs/o4_powerlawprimarymassratio/o4_powerlaw.py" source_line=117}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
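In the logs above, both operands of the slow dots are `%constant`s, i.e. the network weights are baked into the jitted function as compile-time constants. A common workaround for that situation (not discussed in this thread; sketched here with hypothetical toy weights in place of the 88M-parameter MLP) is to pass the parameters as arguments to the jitted function, so XLA treats them as runtime inputs rather than foldable constants:

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny weights standing in for the real MLP parameters.
w1 = jnp.ones((16, 16))
w2 = jnp.ones((16, 16))

def apply_mlp(params, x):
    # Because `params` arrives as a jit *argument*, the weights show up
    # as runtime inputs in the compiled program, and XLA will not
    # constant-fold the dot products over them.
    h = x @ params["w1"]
    return h @ params["w2"]

apply_jit = jax.jit(apply_mlp)

x = jnp.ones((4, 16))
out = apply_jit({"w1": w1, "w2": w2}, x)
print(out.shape)  # (4, 16)
```

Closing over the weights instead (e.g. `jax.jit(lambda x: apply_mlp({"w1": w1, "w2": w2}, x))`) is what embeds them as constants and invites the folding seen in the logs.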