jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

segfault when using ensure_compile_time_eval #18831

Open benjaminvatterj opened 11 months ago

benjaminvatterj commented 11 months ago

Description

Hi! I have a model that requires precomputing a large matrix that is a model constant. To avoid having to compute it at every function call, I thought I'd make it evaluate at compile time. However, this results in a segmentation fault. The same does not happen if I do not ask for compile-time evaluation. I'm not sure if this is a bug or some deeper issue I don't fully understand. This is related to the following question I posed on the Q&A https://github.com/google/jax/discussions/18830.

Here's a minimal (admittedly a bit absurdly minimal) reproducible example:

import jax.numpy as jnp
import jax

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}

@jax.jit
def segfault_version(x):
    with jax.ensure_compile_time_eval():
        segment_id = model_data['segment_id']
        val_id = model_data['val_id']
        num_segments = model_data['num_segments']
        agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
    x = x[val_id]
    return agg_mat @ x

x = jnp.ones(5)
print(segfault_version(x))

The output is simply a segmentation fault. I'm on a Mac Studio M2 (without jax-metal, because too many things are broken).

What jax/jaxlib version are you using?

jax 0.4.20, jaxlib 0.4.20

Which accelerator(s) are you using?

CPU

Additional system info?

1.26.2 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ] uname_result(system='Darwin', node='Bvatter-Studio-23', release='23.1.0', version='Darwin Kernel Version 23.1.0: Mon Oct 9 21:28:45 PDT 2023; root:xnu-10002.41.9~6/RELEASE_ARM64_T6020', machine='arm64')

NVIDIA GPU info

No response

benjaminvatterj commented 11 months ago

I'll add the observation that I get the same segmentation fault if I try to compute the agg_mat outside of the function and reference it within it.

import jax.numpy as jnp
import jax

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}

segment_id = model_data['segment_id']
val_id = model_data['val_id']
num_segments = model_data['num_segments']
agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
model_data['agg_mat'] = agg_mat

@jax.jit
def slow_version2(x):
    x = x[val_id]
    return agg_mat @ x

This also results in a segmentation fault. Is there a better way of passing large constants? The key issue is that it's a matrix, so passing it as a static argument to jit also doesn't seem to work.
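One way to sidestep the large constant altogether, offered as a sketch rather than something proposed in this thread: with the shapes above, the dense agg_mat would hold 11544 × 346320 (about 4 × 10^9) entries, while agg_mat @ x only sums x over each segment. jax.ops.segment_sum computes that sum directly from segment_id without building the matrix. The sizes below are shrunk for illustration, and aggregate is a hypothetical name.

```python
import jax
import jax.numpy as jnp

# Small stand-ins for the sizes in the report (11544 segments, 30 entries each).
num_segments = 4
per_segment = 30
segment_id = jnp.repeat(jnp.arange(num_segments), per_segment)
val_id = jnp.tile(jnp.arange(10), 3 * num_segments)

@jax.jit
def aggregate(x):
    # Gather the values, then sum them per segment without a dense matrix.
    gathered = x[val_id]
    return jax.ops.segment_sum(gathered, segment_id, num_segments=num_segments)

x = jnp.ones(10)
result = aggregate(x)

# Dense reference, affordable only because the example is tiny.
agg_mat = jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1)
dense = agg_mat @ x[val_id]
print(result, dense)
```

Here segment_id and val_id are still captured from the enclosing scope, but they are one-dimensional index arrays rather than a matrix with one row per segment.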

rajasekharporeddy commented 1 month ago

Hi @benjaminvatterj

Thanks for the question. JAX transformations (jit, vmap, etc.) are designed to work on pure functions. The provided function segfault_version() relies on global variables, which makes it impure, so it may not behave as expected under jit. Here's a modified version that defines a pure function suitable for JAX transformations:

import jax.numpy as jnp
import jax
from functools import partial

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}

num_segments = model_data['num_segments']
segment_id = model_data['segment_id']
val_id = model_data['val_id']

@partial(jax.jit, static_argnums=(0,))
def segfault_version(num_segments, segment_id, val_id, x):
  with jax.ensure_compile_time_eval():
    agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
  x = x[val_id]
  return agg_mat @ x

x = jnp.ones(5)
segfault_version(num_segments, segment_id, val_id, x)

Output:

Array([30., 30., 30., ..., 30., 30., 30.], dtype=float32)

num_segments is passed as a static argument because jnp.arange expects a concrete value rather than a tracer.
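The difference between a traced and a static argument can be seen in a small sketch (the function names are hypothetical and not from this thread). Under jit, a regular argument is replaced by a tracer, which jnp.arange cannot turn into an output shape, while an argument listed in static_argnums arrives as an ordinary Python value.

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def arange_traced(n):
    # n is a tracer here, so the output length is unknown at trace time.
    return jnp.arange(n)

@partial(jax.jit, static_argnums=(0,))
def arange_static(n):
    # n is a plain Python int here; each distinct n triggers a recompile.
    return jnp.arange(n)

try:
    arange_traced(4)
    traced_failed = False
except TypeError:
    # JAX's concretization errors subclass TypeError.
    traced_failed = True

static_result = arange_static(4)
print(traced_failed, static_result)
```

Static arguments must be hashable, which is why an integer such as num_segments can be marked static while an array such as segment_id cannot.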
