benjaminvatterj opened 11 months ago
I'll add one more observation: I get the same segmentation fault if I compute agg_mat outside the function and reference it inside it.
```python
import jax.numpy as jnp
import jax

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}
segment_id = model_data['segment_id']
val_id = model_data['val_id']
num_segments = model_data['num_segments']
# agg_mat is computed once, outside the jitted function, and closed over below.
agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
model_data['agg_mat'] = agg_mat

@jax.jit
def slow_version2(x):
    x = x[val_id]
    return agg_mat @ x
```
This also results in a segmentation fault. Is there a better way of passing large constants? The key issue is that it's a matrix (an array, which isn't hashable), so passing it as a static argument to jit doesn't work either.
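One option, sketched here under assumptions, is to pass the precomputed array as an ordinary (traced) argument to the jitted function instead of closing over it or marking it static; jit then treats it as a runtime input rather than a constant embedded in the compiled program. The sizes below are hypothetical toy values so the example runs quickly, and `with_agg_arg` is a hypothetical name. At the sizes in this issue, agg_mat has 11544 × 346320 ≈ 4.0 × 10^9 boolean entries (roughly 4 GB at one byte per entry), and this sketch does not establish whether the approach avoids the crash at that scale.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes (the issue uses num_segments=11544, 30 entries per segment).
num_segments = 4
per_segment = 30
segment_id = jnp.repeat(jnp.arange(num_segments), per_segment)
val_id = jnp.tile(jnp.arange(10), 3 * num_segments)

# Precompute the constant once, outside of jit.
agg_mat = jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1)


@jax.jit
def with_agg_arg(agg_mat, val_id, x):
    # agg_mat and val_id are regular traced inputs, not closed-over constants.
    return agg_mat @ x[val_id]


x = jnp.ones(10)
result = with_agg_arg(agg_mat, val_id, x)
print(result)  # each of the 4 entries is 30.0
```

Because agg_mat is an input rather than a baked-in constant, jit compiles once per input shape and dtype and reuses that compiled program on later calls.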
Hi @benjaminvatterj
Thanks for the question. JAX transformations (jit, vmap, etc.) are designed to work on pure functions. The provided function segfault_version() uses global variables, which makes it impure, so it won't work as expected under jit. Here's a modified version that defines a pure function suitable for JAX transformations:
```python
import jax.numpy as jnp
import jax
from functools import partial

model_data = {
    'num_segments': 11544,
    'segment_id': jnp.repeat(jnp.arange(11544), 30),
    'val_id': jnp.tile(jnp.arange(10), 3 * 11544),
}
num_segments = model_data['num_segments']
segment_id = model_data['segment_id']
val_id = model_data['val_id']

# num_segments is static (a Python int at trace time); the arrays are passed as arguments.
@partial(jax.jit, static_argnums=(0,))
def segfault_version(num_segments, segment_id, val_id, x):
    with jax.ensure_compile_time_eval():
        agg_mat = (jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1))
    x = x[val_id]
    return agg_mat @ x

x = jnp.ones(5)
segfault_version(num_segments, segment_id, val_id, x)
```
Output:
```
Array([30., 30., 30., ..., 30., 30., 30.], dtype=float32)
```
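The value 30 follows from the construction: segment_id repeats each segment index 30 times, so every row of agg_mat has exactly 30 True entries, and x[val_id] is a vector of ones, so each row sums to 30. (val_id goes up to 9 while x has only 5 entries; the all-30 output is consistent with JAX's documented clamping of out-of-bounds indices in indexing lookups.) The sketch below checks this row-sum reading on hypothetical toy sizes and, as a swapped-in technique not used in this thread, compares it with `jax.ops.segment_sum`, which computes the same per-segment sums without materializing the dense matrix.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy sizes; the issue uses 11544 segments.
num_segments = 6
segment_id = jnp.repeat(jnp.arange(num_segments), 30)
val_id = jnp.tile(jnp.arange(10), 3 * num_segments)
x = jnp.ones(10)  # length 10 so every val_id is in bounds

agg_mat = jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1)

# Every row of the dense boolean matrix has exactly 30 True entries.
row_counts = agg_mat.sum(axis=1)

# Dense version, as in the thread.
dense = agg_mat @ x[val_id]
# Same per-segment sums without building agg_mat.
sparse = jax.ops.segment_sum(x[val_id], segment_id, num_segments=num_segments)
print(row_counts, dense, sparse)
```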
num_segments is passed as a static argument because jnp.arange expects a concrete value rather than a tracer.
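A minimal sketch of why the static argument is needed, using a hypothetical helper `make_range`: with a traced `n`, jnp.arange cannot determine its output shape at trace time and raises `jax.errors.ConcretizationTypeError`; marking `n` as static makes it a plain Python int during tracing.

```python
from functools import partial

import jax
import jax.numpy as jnp


def make_range(n):
    # jnp.arange needs a concrete n because n determines the output shape.
    return jnp.arange(n)


# Without static_argnums, n is traced and jnp.arange cannot use it.
traced_failed = False
try:
    jax.jit(make_range)(5)
except jax.errors.ConcretizationTypeError:
    traced_failed = True

# With static_argnums=(0,), n is a Python int during tracing.
static_range = partial(jax.jit, static_argnums=(0,))(make_range)
out = static_range(5)
print(traced_failed, out)  # True [0 1 2 3 4]
```

Each distinct static value triggers a separate compilation, which is fine for a model constant like num_segments.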
Please find the screenshot for reference:
Description
Hi! I have a model that requires precomputing a large matrix that is a model constant. To avoid recomputing it on every function call, I thought I'd have it evaluated at compile time. However, this results in a segmentation fault. The same does not happen if I don't ask for compile-time evaluation. I'm not sure whether this is a bug or a deeper issue I don't fully understand. This is related to the following question I posted in the Q&A: https://github.com/google/jax/discussions/18830.
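For reference, a small sketch of the mechanism the report relies on: inside `jax.ensure_compile_time_eval()`, operations whose inputs are concrete (non-traced) values run eagerly while the function is being traced, so the result is a concrete array that jit embeds as a constant. Operations involving traced arguments are still staged into the compiled program. The sizes and the name `with_compile_time_constant` are hypothetical; this does not reproduce the crash, which the thread reports for a matrix of shape 11544 × 346320.

```python
import jax
import jax.numpy as jnp

num_segments = 4  # hypothetical toy size
segment_id = jnp.repeat(jnp.arange(num_segments), 3)


@jax.jit
def with_compile_time_constant(x):
    with jax.ensure_compile_time_eval():
        # segment_id is a concrete global, so this runs eagerly at trace time
        # and agg_mat is a concrete array, not a tracer.
        agg_mat = jnp.arange(num_segments).reshape(-1, 1) == segment_id.reshape(1, -1)
        # Python-level checks on the value work only because it is concrete.
        assert int(agg_mat.sum()) == segment_id.shape[0]
    return agg_mat @ x


out = with_compile_time_constant(jnp.ones(12))
print(out)  # [3. 3. 3. 3.]
```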
Here's a minimal (admittedly, absurdly minimal) reproducible example:
The output is simply a segmentation fault. I'm on a Mac Studio M2 (not using jax-metal, because too many things are broken there).
What jax/jaxlib version are you using?
jax 0.4.20, jaxlib 0.4.20
Which accelerator(s) are you using?
CPU
Additional system info?
1.26.2 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ] uname_result(system='Darwin', node='Bvatter-Studio-23', release='23.1.0', version='Darwin Kernel Version 23.1.0: Mon Oct 9 21:28:45 PDT 2023; root:xnu-10002.41.9~6/RELEASE_ARM64_T6020', machine='arm64')
NVIDIA GPU info
No response