mehdiataei opened this issue 1 year ago
@kulinseth
I should note: since AIUI the Metal plugin only supports single GPUs, all shardings are necessarily trivial and do nothing. So the correct action is to ignore any `Sharding` custom-calls.
(Also, resharding should be a first-class StableHLO operator, not a custom-call, but that's for future work.)
Is there any way to ignore this in the code without removing the shardings?
@mehdiataei There are three options here: a) Apple teaches the Metal plugin how to handle these. b) We add a special case to JAX to strip shardings out when using the Metal plugin. c) We teach JAX not to emit sharding constraints when there is only one device and they are degenerate.
I'm going to reject (b) for now. It might be reasonable for us to do both (a) and (c). Certainly (a) is sensible no matter what.
I think (c) is the best option, and I have noticed small performance gains when I explicitly remove shardings while using 1 device.
P.S. I am not sure whether that is due to the removal of the sharding constraints or to the `ppermute` operation in `shard_map`.
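Until (a) or (c) lands, option (c) can be approximated by hand in user code. A minimal sketch, assuming a hypothetical helper `maybe_constrain` (not part of JAX) that skips the constraint on single-device backends so no `Sharding` custom-call is emitted:

```python
import jax
import jax.numpy as jnp

def maybe_constrain(x, sharding):
    """Apply a sharding constraint only when more than one device is visible.

    On a single-device backend such as jax-metal the constraint is trivial
    anyway, and omitting it keeps the Sharding custom-call out of the
    lowered HLO.
    """
    if jax.device_count() > 1:
        return jax.lax.with_sharding_constraint(x, sharding)
    return x
```

This is only a sketch; it assumes the constraint is genuinely degenerate, which holds whenever exactly one device is visible.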
> I should note: since AIUI the metal plugin only supports single GPUs, all shardings are necessarily trivial and do nothing. So the correct action is to ignore any `Sharding` custom-calls. (Also resharding should be a first-class StableHLO operator, not a custom-call, but that's for future work.)
Thanks @hawkinsp for following up with options. `custom-call` is indeed not supported in the current version of the Metal plugin. We have definitely seen custom calls used across networks, so there is definite value in adding support. We will update here when we have it.
Hi guys, are there any further updates on getting support for `custom-call`? Currently experiencing the same issues as above with `jax-metal-0.0.6` and `jaxlib-0.4.23` when using an M3 Pro.
Seems like this is causing `flax==0.8.4` to fail. For instance, running https://github.com/Reytuag/transformerXL_PPO_JAX gives me the following error:
I0000 00:00:1717204367.698949 1351660 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/subho/Documents/magic/transformerXL_PPO_JAX/train_PPO_trXL.py", line 56, in <module>
out = train_jit(rng)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/subho/base/lib/python3.10/site-packages/flax/core/scope.py:990:0: error: failed to legalize operation 'mhlo.custom_call'
/Users/subho/base/lib/python3.10/site-packages/flax/core/scope.py:990:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:1889:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/linear.py:256:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:1233:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:701:0: note: called from
/Users/subho/Documents/magic/transformerXL_PPO_JAX/trainer_PPO_trXL.py:78:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:1233:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:701:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:3103:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/core/scope.py:1101:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/core/scope.py:990:0: note: see current operation: %32856:2 = "mhlo.custom_call"(%32855) {api_version = 1 : i32, backend_config = "", call_target_name = "Qr", called_computations = [], has_side_effect = false} : (tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256xf32>)
Wondering if there have been any updates.
Here's a pretty minimal repro, environment info below the error message:
import jax
import jax.numpy as jnp
import flax.linen as nn
model = nn.Dense(features=5)
params = model.init(jax.random.key(0), jnp.ones((5,5)))
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[102], line 2
1 model = nn.Dense(features=5)
----> 2 params = model.init(jax.random.key(0), jnp.ones((3,3)))
[... skipping hidden 9 frame]
File ~/jax-metal/lib/python3.9/site-packages/flax/linen/linear.py:256, in Dense.__call__(self, inputs)
246 @compact
247 def __call__(self, inputs: Array) -> Array:
248 """Applies a linear transformation to the inputs along the last dimension.
249
250 Args:
(...)
254 The transformed input.
255 """
--> 256 kernel = self.param(
257 'kernel',
258 self.kernel_init,
259 (jnp.shape(inputs)[-1], self.features),
260 self.param_dtype,
261 )
262 if self.use_bias:
263 bias = self.param(
264 'bias', self.bias_init, (self.features,), self.param_dtype
265 )
[... skipping hidden 2 frame]
File ~/jax-metal/lib/python3.9/site-packages/jax/_src/nn/initializers.py:335, in variance_scaling.<locals>.init(key, shape, dtype)
332 if jnp.issubdtype(dtype, jnp.floating):
333 # constant is stddev of standard normal truncated to (-2, 2)
334 stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
--> 335 return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev
336 else:
337 # constant is stddev of complex standard normal truncated to 2
338 stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
File ~/jax-metal/lib/python3.9/site-packages/jax/_src/random.py:852, in truncated_normal(key, lower, upper, shape, dtype)
850 if shape is not None:
851 shape = core.as_named_shape(shape)
--> 852 return _truncated_normal(key, lower, upper, shape, dtype)
[... skipping hidden 14 frame]
File ~/jax-metal/lib/python3.9/site-packages/jax/_src/compiler.py:238, in backend_compile(backend, module, options, host_callbacks)
233 return backend.compile(built_c, compile_options=options,
234 host_callbacks=host_callbacks)
235 # Some backends don't have `host_callbacks` option yet
236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: /Users/rt/jax-metal/lib/python3.9/site-packages/flax/core/scope.py:990:0: error: failed to legalize operation 'mhlo.erf'
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:1889:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/linear.py:256:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:1233:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:701:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:3103:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/core/scope.py:1101:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/core/scope.py:1137:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:2316:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:2464:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/core/scope.py:990:0: note: see current operation: %111 = "mhlo.erf"(%110) : (tensor<f32>) -> tensor<f32>
Here's my environment info: jax: 0.4.30, jaxlib: 0.4.30, numpy: 1.26.4, python: 3.9.6 (default, May 7 2023, 23:32:44) [Clang 14.0.3 (clang-1403.0.22.14.1)], jax.devices (1 total, 1 local): [METAL(id=0)]
Flax: 0.8.5. Apple M2 Pro, macOS Sonoma 14.6.1.
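The `mhlo.erf` failure in the trace above comes from `jax.random.truncated_normal`, which Flax's default `lecun_normal` kernel initializer calls. A hedged stopgap until the plugin legalizes `erf`: use a plain normal initializer instead (the `stddev` here is an arbitrary illustration, not Flax's default variance scaling):

```python
import jax
import jax.numpy as jnp

# jax.nn.initializers.normal samples via jax.random.normal, which does not
# go through the erf-based truncation path used by truncated_normal.
init = jax.nn.initializers.normal(stddev=0.02)
kernel = init(jax.random.key(0), (5, 5), jnp.float32)
```

With Flax you would pass it as `nn.Dense(features=5, kernel_init=init)`; note this changes the initialization distribution, so treat it as a workaround sketch rather than a drop-in replacement.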
Upgraded to macOS Sequoia 15 Beta 7 and am now getting this with the numpyro sampler:
File ~/micromamba/envs/pymc5_new/lib/python3.12/site-packages/numpyro/util.py:400, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
398 with tqdm.trange(upper) as t:
399 for i in t:
--> 400 vals = _body_fn(i, *vals)
402 t.set_description(progbar_desc(i), refresh=False)
403 if diagnostics_fn:
[... skipping hidden 14 frame]
File ~/micromamba/envs/pymc5_new/lib/python3.12/site-packages/jax/_src/compiler.py:267, in backend_compile(backend, module, options, host_callbacks)
262 return backend.compile(built_c, compile_options=options,
263 host_callbacks=host_callbacks)
264 # Some backends don't have `host_callbacks` option yet
265 # TODO(sharadmv): remove this fallback when all backends allow `compile`
266 # to take in `host_callbacks`
--> 267 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: <ipython-input-1-5ac960273fe0>:9:12: error: failed to legalize operation 'mhlo.popcnt'
<ipython-input-3-84e33ccd7f24>:9:12: note: called from
/Users/twiecki/micromamba/envs/pymc5_new/bin/ipython:10:13: note: called from
sys.exit(start_ipython())
^
<ipython-input-1-5ac960273fe0>:9:12: note: see current operation: %3616 = "mhlo.popcnt"(%3615) : (tensor<si32>) -> tensor<si32>
The original response on not supporting CustomCall: jax-metal doesn't support all custom-call targets, and may not in the near future. The `Sharding` target is supported as an identity; the 'Qr' target is not supported. `mhlo.popcnt` is not supported so far.
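Until `mhlo.popcnt` is legalized, population count can be emulated with the classic SWAR bit trick, avoiding `jax.lax.population_count` entirely. A sketch (hypothetical helper, valid only for 32-bit integers):

```python
import jax.numpy as jnp

def popcount32(x):
    """Bit-twiddling population count for 32-bit integers, avoiding
    jax.lax.population_count (which lowers to mhlo.popcnt and fails to
    legalize on jax-metal)."""
    x = x.astype(jnp.uint32)
    x = x - ((x >> 1) & 0x55555555)          # 2-bit partial counts
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333)  # 4-bit partial counts
    x = (x + (x >> 4)) & 0x0F0F0F0F          # 8-bit partial counts
    return ((x * 0x01010101) >> 24).astype(jnp.int32)  # sum the bytes
```

Whether this matches the performance of the native op is untested; it is only a fallback until the plugin gains `popcnt` support.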
Description
The mhlo.custom_call operation in the create_grid_connectivity_bitmask function failed to legalize, leading to an XlaRuntimeError.
Chipset: Apple M1 Pro. JAX version: installed via the instructions at https://developer.apple.com/metal/jax/
Repro: