google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.lax.scan() segmentation fault with jax-metal on Mac M1 #20750

Open kdesoto-astro opened 5 months ago

kdesoto-astro commented 5 months ago

Description

Calling jax.lax.scan(), jax.lax.map(), and related functions causes a segmentation fault. The segfault can be traced back to the core.AxisPrimitive().bind() call. It is reproducible with jax-metal 0.0.5 and 0.0.4, on both M1 and M2 MacBook Pros.

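import jax
rng = jax.random.PRNGKey(0)
test_input = jax.random.normal(key=rng, shape=(5,5,5))  # xs: scanned along the leading axis
initial_state = jax.numpy.array(0.0)  # scalar carry

def test_func(x, y):
    # x is the carry, y is one (5, 5) slice of test_input;
    # returns (new carry, per-step output)
    return x, y

# Segfaults on the Metal backend
x, y = jax.lax.scan(test_func, initial_state, test_input)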

System info (python version, jaxlib version, accelerator, etc.)

Device: Apple M1 Pro (and M2)
macOS: Sonoma 14.4 (and 14.5 Beta)
jax-metal: 0.0.6
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.0 (main, Mar 1 2023, 12:33:14) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', release='23.5.0', version='Darwin Kernel Version 23.5.0', machine='arm64')
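The fields above match what jax.print_environment_info() prints, so the block was likely generated with it. A short sketch for collecting the same details, plus the default backend name, when reporting a similar problem:

```python
import jax

# Prints jax/jaxlib/numpy/python versions, the device list, process count
# and platform, i.e. the same fields as the system info block above.
jax.print_environment_info()

backend = jax.default_backend()  # platform name of the default backend, e.g. "cpu"
devices = jax.devices()          # devices on the default backend
print(backend, devices)
```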

carloswert commented 5 months ago

Thank you for posting. I have the same error, and I also reported it on the Apple Developer Forums for the Apple team (https://forums.developer.apple.com/forums/thread/750160).

twiecki commented 3 months ago

Running into the same issue, but getting an error:

Running window adaptation
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 688, in sample
    return _sample_external_nuts(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 351, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 567, in sample_jax_nuts
    raw_mcmc_samples, sample_stats, library = sampler_fn(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 413, in _sample_blackjax_nuts
    states, stats = map_fn(get_posterior_samples)(keys, initial_points)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 250, in _blackjax_inference_loop
    (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/adaptation/window_adaptation.py", line 334, in run
    last_state, info = jax.lax.scan(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 80, in wrapper_progress_bar
    _update_progress_bar(iter_num)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 46, in _update_progress_bar
    _ = lax.cond(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 48, in <lambda>
    lambda _: io_callback(_define_bar, None, iter_num),
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/jax/_src/callback.py", line 502, in io_callback
    out_flat = io_callback_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: `EmitPythonCallback` not supported on METAL backend.
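
The failing frame in this traceback is a host callback (io_callback) issued from inside a lax.scan body, which blackjax's progress bar uses. The sketch below reproduces that pattern in isolation with hypothetical names (record, body) and pins it to the CPU device, where host callbacks are supported; on the Metal backend it would presumably hit the same `EmitPythonCallback` error.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

seen = []  # host-side record of iterations, filled by the callback

def record(i):
    # Runs on the host with a NumPy value and returns nothing.
    seen.append(int(i))

def body(carry, x):
    # Host callback from inside the scan body, like blackjax's progress bar.
    io_callback(record, None, carry)
    return carry + 1, x * 2.0

cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    final, ys = jax.lax.scan(body, jnp.int32(0), jnp.arange(5.0))

final.block_until_ready()  # wait for the scan (and its callbacks) to finish
jax.effects_barrier()      # flush any pending side effects
```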

paullabonne commented 3 months ago

I also have the same issue with lax.scan() when running on the Metal GPU: it either segfaults or crashes the Jupyter kernel. On the CPU, with jax.config.update('jax_platform_name', 'cpu'), it works fine.

tsumme1 commented 3 months ago

I am having the same issue as well. The scan function causes a segfault on jax-metal 0.0.5 through 0.0.7 with an M3 Max (macOS Sonoma 14.5). I also found that the error is triggered when values from xs (inp in the code below) are used inside the scanned function; scan still works when the body only uses the carry. As a stopgap, I wrote the function below, which avoids the error while behaving similarly.

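import jax
import jax.numpy as jnp
from jax.tree import map as tree_map

def compat_scan(f, carry, xs, unroll=1, length=None):
    # Step counter carried alongside the state. Shape (1,) so that x[k]
    # uses array indexing; x[k][0] then recovers the k-th slice.
    ind = jnp.zeros(1, jnp.uint32)

    def step(c, inp):
        # inp (the slice scan passes in) is deliberately ignored; the
        # slice is read from xs manually via the counter instead.
        state, k = c
        # tree_map handles both a single array and a pytree of arrays.
        vals = tree_map(lambda x: x[k][0], xs)
        state, out = f(state, vals)
        return (state, k + jnp.uint32(1)), out

    (carry, _), ys = jax.lax.scan(step, (carry, ind), xs, unroll=unroll, length=length)
    return carry, ys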
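
A quick way to check that compat_scan behaves like lax.scan is to run both on a body that reads its xs slice (the case reported to crash) and compare the results. The sketch below repeats a self-contained copy of compat_scan, including the jax.numpy import it needs; cumsum_step is a hypothetical example body, and the comparison should be run on a backend where lax.scan works, such as CPU.

```python
import jax
import jax.numpy as jnp
from jax.tree import map as tree_map

def compat_scan(f, carry, xs, unroll=1, length=None):
    ind = jnp.zeros(1, jnp.uint32)  # step counter, shape (1,)

    def step(c, inp):
        state, k = c
        # Read the k-th slice of every leaf of xs instead of using inp.
        vals = tree_map(lambda x: x[k][0], xs)
        state, out = f(state, vals)
        return (state, k + jnp.uint32(1)), out

    (carry, _), ys = jax.lax.scan(step, (carry, ind), xs, unroll=unroll, length=length)
    return carry, ys

def cumsum_step(total, x):
    # Uses the per-step slice of xs: running sum of the rows.
    total = total + x
    return total, total

xs = jnp.arange(6.0).reshape(3, 2)
c1, ys1 = compat_scan(cumsum_step, jnp.zeros(2), xs)
c2, ys2 = jax.lax.scan(cumsum_step, jnp.zeros(2), xs)
```

Note that compat_scan still passes xs to lax.scan (scan slices it, but the body ignores that slice), so it changes how the values are read, not what is computed.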

shuhand0 commented 3 months ago

We are aware of the issue and working on a fix.

adam-hartshorne commented 3 months ago

This still doesn't appear to be fixed in jax-metal 0.1.0.

shuhand0 commented 3 months ago

The fix will ship in the next public macOS release and will require an OS upgrade.

bsarkar321 commented 1 month ago

Hi @shuhand0. I upgraded to the public macOS release that came out today (Darwin Kernel Version 23.6.0), and the segmentation fault persists, both with the original poster's library versions and with the latest versions.

kdesoto-astro commented 2 weeks ago

@shuhand0 Is there an update on this? With the OS update and jax-metal==0.1.0, the segfault still occurs. It's concerning that a core functionality of JAX has been broken for all Apple Silicon Mac users for the past 4 months, with no fix or workaround.

shuhand0 commented 2 weeks ago

The fix is in the MetalPerformanceShadersGraph framework in macOS Sequoia. Could you try the test on the latest macOS 15 Beta 7?
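Since the fix reportedly ships in macOS 15 (Sequoia), code that must also run on older macOS versions could route scans to the CPU there. A minimal sketch: macos_major and scan_device are hypothetical helpers, the comparison of jax.default_backend() against "metal" is case-insensitive because the exact platform string jax-metal reports is an assumption here, and a version that can't be parsed errs toward CPU.

```python
import platform
import jax

def macos_major(release):
    # platform.mac_ver()[0] is e.g. "15.0" on macOS and "" elsewhere.
    return int(release.split(".")[0]) if release else None

def scan_device(min_macos=15):
    """Hypothetical helper: pick the CPU on Metal when macOS is older than
    the release that, per this thread, contains the scan fix."""
    major = macos_major(platform.mac_ver()[0])
    # Assumed platform name for jax-metal; compared case-insensitively.
    on_metal = jax.default_backend().lower() == "metal"
    if on_metal and (major is None or major < min_macos):
        return jax.devices("cpu")[0]
    return jax.devices()[0]
```

Usage: `with jax.default_device(scan_device()): carry, ys = jax.lax.scan(f, init, xs)`.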

bsarkar321 commented 2 weeks ago

This worked for me (no longer segfaulting)! I'm on macOS 15 Beta 7 and verified it with both the original poster's library versions and the latest versions.

kdesoto-astro commented 2 weeks ago

This works for me too after updating to Sequoia - thank you!!