kdesoto-astro opened this issue 7 months ago
Thank you for posting. I have the same error, and I also reported it on the Apple Developer Forums so the Apple team sees it (https://forums.developer.apple.com/forums/thread/750160).
Running into the same issue, but I get a different error:
```
Running window adaptation
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 688, in sample
    return _sample_external_nuts(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 351, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 567, in sample_jax_nuts
    raw_mcmc_samples, sample_stats, library = sampler_fn(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 413, in _sample_blackjax_nuts
    states, stats = map_fn(get_posterior_samples)(keys, initial_points)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/pymc/sampling/jax.py", line 250, in _blackjax_inference_loop
    (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/adaptation/window_adaptation.py", line 334, in run
    last_state, info = jax.lax.scan(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 80, in wrapper_progress_bar
    _update_progress_bar(iter_num)
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 46, in _update_progress_bar
    _ = lax.cond(
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/blackjax/progress_bar.py", line 48, in <lambda>
    lambda _: io_callback(_define_bar, None, iter_num),
  File "/Users/twiecki/jax-metal/lib/python3.11/site-packages/jax/_src/callback.py", line 502, in io_callback
    out_flat = io_callback_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: `EmitPythonCallback` not supported on METAL backend.
```
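This traceback ends in a host callback rather than in `scan` itself: the blackjax progress bar calls `io_callback` from inside `lax.cond` on each step of the adaptation `lax.scan`, and the Metal plugin reports that it cannot lower `EmitPythonCallback`. The sketch below isolates that pattern. It forces the CPU backend (an assumption, so that it runs anywhere); on the Metal backend one would expect this same `ValueError`, although that is not verified here. The names `report` and `step` are hypothetical.

```python
import jax

# Assumption: force the CPU backend so this sketch runs on any machine.
# Set it before any JAX computation runs.
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp
from jax.experimental import io_callback

seen = []  # host-side log of the steps at which the callback fired


def report(i):
    # Runs in Python on the host, outside the compiled computation.
    seen.append(int(i))


def step(carry, x):
    i, total = carry
    # Same kind of call as the blackjax progress bar in the traceback:
    # a host io_callback inside lax.cond, evaluated on every scan step.
    jax.lax.cond(
        i % 5 == 0,
        lambda: io_callback(report, None, i),
        lambda: None,
    )
    return (i + 1, total + x), None


(n, total), _ = jax.lax.scan(
    step, (jnp.int32(0), jnp.float32(0.0)), jnp.arange(10.0)
)
jax.block_until_ready(total)  # make sure all callbacks have run
print(int(n), float(total), seen)
```

If the callback is the culprit, running the sampler with its progress bar turned off might sidestep this particular error, but that is an untested assumption and would not help with the `scan` segfault reported elsewhere in this thread.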
I also have the same issue with `lax.scan()` when running on the Metal GPU: it either segfaults or crashes the Jupyter kernel. On the CPU, with `jax.config.update('jax_platform_name', 'cpu')`, it works fine.
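A minimal sketch of that CPU fallback, using the same `jax.config.update` call; the cumulative-sum body `step` is an illustrative assumption. To be safe, the setting should come at the top of the program, before JAX runs any computation.

```python
import jax

# Select the CPU as the default backend instead of METAL.
# Do this before any arrays are created or any computation runs.
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp


def step(carry, x):
    # Running sum: the new carry is also emitted as this step's output.
    carry = carry + x
    return carry, carry


total, running = jax.lax.scan(step, jnp.float32(0.0), jnp.arange(1.0, 5.0))
print(jax.devices(), total, running)
```

To limit the fallback to part of a program instead of the whole process, `jax.default_device(jax.devices("cpu")[0])` should work as a context manager around just the affected calls.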
I am having the same issue as well. `scan` segfaults on jax-metal 0.0.5 through 0.0.7 on an M3 Max (macOS Sonoma 14.5). I also found that the error is triggered when values from `xs` (`inp` in the code below) are used inside the scanned function; `scan` still works when only the carry is used. As a stopgap, I wrote the function below, which avoids the error while behaving similarly.
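```python
import jax
import jax.numpy as jnp
from jax.tree import map as tree_map


def compat_scan(f, carry, xs, unroll=1, length=None):
    # Stopgap for jax.lax.scan on jax-metal: `f` never receives the per-step
    # slice `inp`; a counter `k` is carried along and used to index `xs`.
    ind = jnp.zeros(1, jnp.uint32)  # step counter with shape (1,)

    def step(c, inp):  # `inp` is intentionally unused
        state, k = c
        # tree_map covers both a single array and a pytree of arrays
        # (a bare array is a one-leaf pytree).
        vals = tree_map(lambda x: x[k][0], xs)
        state, out = f(state, vals)
        return (state, k + jnp.uint32(1)), out

    (carry, _), ys = jax.lax.scan(step, (carry, ind), xs, unroll=unroll, length=length)
    return carry, ys
```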
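The stopgap is meant to return the same `(carry, ys)` pair as `jax.lax.scan`. For comparison, the semantics of `scan` can be written as a plain Python loop, adapted from the reference description in the `jax.lax.scan` documentation. This sketch handles a single array `xs` only (no pytrees), and `reference_scan` and `step` are hypothetical names. Running the same `f` through this loop and through `compat_scan` is one way to check that the workaround matches.

```python
import numpy as np


def reference_scan(f, init, xs, length=None):
    # Loop semantics of jax.lax.scan for a single array `xs`.
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # f returns (new_carry, per-step output)
        ys.append(y)
    return carry, np.stack(ys)


def step(carry, x):
    carry = carry + x
    return carry, carry * 2


carry, ys = reference_scan(step, 0, np.arange(4))
print(carry, ys)
```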
We are aware of the issue and working on a fix.
This still doesn't appear to be fixed in jax-metal 0.1.0.
The fix will ship in the next public macOS release and will require an OS upgrade.
Hi @shuhand0. I upgraded to the public OS update released today (Darwin Kernel Version 23.6.0) and still get a segmentation fault, both with the original poster's library versions and with the latest versions.
@shuhand0 Is there an update on this? With the OS update and jax-metal==0.1.0, the segfault still occurs. It's concerning that core JAX functionality has been broken for all Apple Silicon Mac users for the past 4 months, with no fix or workaround.
The fix is in the MetalPerformanceShadersGraph framework in macOS Sequoia. Could you try the test on the latest macOS 15 Beta 7?
This worked for me (no longer segfaulting)! I'm using macOS 15 Beta 7 and validated this with both the original poster's library versions and the latest versions.
This works for me too after updating to Sequoia - thank you!!
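Because the fix ships in the MetalPerformanceShadersGraph framework of macOS Sequoia rather than in jax-metal, a program that must also run on older macOS could check the OS version and fall back to the CPU backend there. A minimal sketch, assuming macOS 15 is the first fixed release (as reported in this thread) and that jax-metal is the only GPU backend in play; `needs_cpu_fallback` is a hypothetical helper.

```python
import platform


def needs_cpu_fallback(mac_version: str, first_fixed_major: int = 15) -> bool:
    """Return True on macOS releases older than Sequoia (hypothetical helper).

    `mac_version` is a string such as "14.5" or "15.0", as in the first field
    of platform.mac_ver(). It is empty on non-macOS systems, where there is no
    Metal backend to avoid.
    """
    if not mac_version:
        return False
    major = int(mac_version.split(".")[0])
    return major < first_fixed_major


if needs_cpu_fallback(platform.mac_ver()[0]):
    import jax  # imported only when needed, before any JAX computation

    jax.config.update("jax_platform_name", "cpu")
```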
Description
Segmentation fault when calling `jax.lax.scan()`, `jax.lax.map()`, and related functions. The segmentation fault can be traced back to the `core.AxisPrimitive().bind()` call. Reproducible with jax-metal==0.0.5 and jax-metal==0.0.4, on both M1 and M2 MacBook Pros.
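A minimal script of the kind described, exercising both `jax.lax.scan` and `jax.lax.map`. The bodies are illustrative assumptions, not the original reproducer: `carry_only` ignores the per-step slice while `uses_xs` reads it, mirroring a commenter's report that only bodies reading `xs` crash. On an unaffected backend such as the CPU, all three calls complete normally.

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(8, dtype=jnp.float32)


def carry_only(carry, _):
    # Ignores the per-step slice; reported to keep working on Metal.
    return carry + 1.0, carry


def uses_xs(carry, x):
    # Reads the per-step slice `x`; reported to segfault on affected setups.
    carry = carry + x
    return carry, carry


n, counts = jax.lax.scan(carry_only, jnp.float32(0.0), xs)
total, running = jax.lax.scan(uses_xs, jnp.float32(0.0), xs)
squares = jax.lax.map(lambda v: v * v, xs)

print(jax.devices())
print(n, total, squares)
```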
System info (python version, jaxlib version, accelerator, etc.)
```
Device: Apple M1 Pro (and M2)
macOS: Sonoma 14.4 (and 14.5 Beta)
jax-metal: 0.0.6
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.0 (main, Mar 1 2023, 12:33:14) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', release='23.5.0', version='Darwin Kernel Version 23.5.0', machine='arm64')
```
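The jax/jaxlib/numpy/python/devices/platform lines above have the layout printed by `jax.print_environment_info()`. A sketch that collects the same report programmatically and prepends the jax-metal version, which that function does not include; looking it up with `importlib.metadata` assumes jax-metal was installed as a regular pip package.

```python
import importlib.metadata

import jax

# return_string=True returns the report instead of printing it.
info = jax.print_environment_info(return_string=True)

try:
    metal_version = importlib.metadata.version("jax-metal")
except importlib.metadata.PackageNotFoundError:
    metal_version = "not installed"

report = f"jax-metal: {metal_version}\n{info}"
print(report)
```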