jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Apple Silicon: error: failed to legalize operation 'mhlo.custom_call' #16287

Open mehdiataei opened 1 year ago

mehdiataei commented 1 year ago

Description

The mhlo.custom_call operation in the create_grid_connectivity_bitmask function failed to legalize, leading to an XlaRuntimeError.

  File "XLB/src/base.py", line 163, in _create_boundary_data
    connectivity_bitmask = self.create_grid_connectivity_bitmask(solid_halo_voxels)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN:XLB/src/base.py:209:0: error: failed to legalize operation 'mhlo.custom_call'
XLB/src/base.py:239:0: note: called from
XLB/src/base.py:209:0: note: see current operation: %42 = "mhlo.custom_call"(%1) {api_version = 1 : i32, backend_config = "", call_target_name = "Sharding", called_computations = [], has_side_effect = false, mhlo.sharding = "{maximal device=0}"} : (tensor<102x102x102x19xi1>) -> tensor<102x102x102x19xi1>

Chipset: Apple M1 Pro
JAX version: installed via the instructions at https://developer.apple.com/metal/jax/

Repro:

# Install JAX via: https://developer.apple.com/metal/jax/
# Install XLB dependencies
pip3 install jmp==0.0.4 matplotlib==3.7.1 numpy==1.24.2 pyvista==0.38.5 Rtree==1.0.1 trimesh==3.20.2 orbax-checkpoint==0.2.4 termcolor==2.3.0

git clone https://github.com/Autodesk/XLB
cd XLB
export PYTHONPATH=.
python3 examples/MLUPS3d.py 100 100
# In the near future the file will be moved to examples/performance/MLUPS3d.py

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

hawkinsp commented 1 year ago

@kulinseth

hawkinsp commented 1 year ago

I should note: since AIUI the metal plugin only supports single GPUs, all shardings are necessarily trivial and do nothing. So the correct action is to ignore any Sharding custom-calls.

(Also resharding should be a first-class stable HLO operator, not a custom-call, but that's for future work.)

mehdiataei commented 1 year ago

Is there any way to ignore this in the code without removing the shardings?

hawkinsp commented 1 year ago

@mehdiataei There are three options here:
a) Apple teaches the Metal plugin how to handle these.
b) We add a special case to JAX to strip shardings out when using the Metal plugin.
c) We teach JAX not to emit sharding constraints when there is only one device and they are degenerate.

I'm going to reject (b) for now. It might be reasonable for us to do both (a) and (c). Certainly (a) is sensible no matter what.

mehdiataei commented 1 year ago

I think (c) is the best option; I have noticed small performance gains when I explicitly remove shardings while using one device.

P.S. I am not sure whether that is due to the removal of the sharding constraints or to the ppermute operation in shard_map.
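For anyone who wants to make that removal conditional rather than deleting sharding annotations outright, here is a minimal sketch. `maybe_shard` is a hypothetical helper, not a JAX API; it assumes the goal is simply to skip the constraint on single-device backends such as jax-metal:

```python
import jax
from jax.lax import with_sharding_constraint

def maybe_shard(x, sharding):
    # Hypothetical helper (not part of JAX): apply the sharding
    # constraint only when more than one device is present. On a
    # single-device backend such as jax-metal this avoids emitting
    # the Sharding custom-call that fails to legalize.
    if jax.device_count() == 1:
        return x
    return with_sharding_constraint(x, sharding)
```

Code that previously called `with_sharding_constraint(x, sharding)` directly would call `maybe_shard(x, sharding)` instead, keeping the multi-device path intact.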

kulinseth commented 1 year ago

> I should note: since AIUI the metal plugin only supports single GPUs, all shardings are necessarily trivial and do nothing. So the correct action is to ignore any Sharding custom-calls.
>
> (Also resharding should be a first-class stable HLO operator, not a custom-call, but that's for future work.)

Thanks @hawkinsp for following up with options. Custom calls are indeed not supported in the current version of the Metal plugin. We have definitely seen custom calls used across networks, so there is real value in adding support. We will update here when we have it.

liamclarkza commented 6 months ago

Hi guys, are there any further updates on getting support for custom-call? Currently experiencing the same issues as above with jax-metal-0.0.6 and jaxlib-0.4.23 when using an M3 Pro.

subho406 commented 5 months ago

It seems this is causing flax==0.8.4 to fail. For instance, running https://github.com/Reytuag/transformerXL_PPO_JAX gives me the following error:

I0000 00:00:1717204367.698949 1351660 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/subho/Documents/magic/transformerXL_PPO_JAX/train_PPO_trXL.py", line 56, in <module>
    out = train_jit(rng)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/subho/base/lib/python3.10/site-packages/flax/core/scope.py:990:0: error: failed to legalize operation 'mhlo.custom_call'
/Users/subho/base/lib/python3.10/site-packages/flax/core/scope.py:990:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:1889:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/linear.py:256:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:1233:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:701:0: note: called from
/Users/subho/Documents/magic/transformerXL_PPO_JAX/trainer_PPO_trXL.py:78:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:1233:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:701:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/linen/module.py:3103:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/core/scope.py:1101:0: note: called from
/Users/subho/base/lib/python3.10/site-packages/flax/core/scope.py:990:0: note: see current operation: %32856:2 = "mhlo.custom_call"(%32855) {api_version = 1 : i32, backend_config = "", call_target_name = "Qr", called_computations = [], has_side_effect = false} : (tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256xf32>)

Wondering if there have been any updates.

TylerMclaughlin commented 2 months ago

Here's a pretty minimal repro, environment info below the error message:

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=5)
params = model.init(jax.random.key(0), jnp.ones((5, 5)))
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[102], line 2
      1 model = nn.Dense(features=5)
----> 2 params = model.init(jax.random.key(0), jnp.ones((3,3))) 

    [... skipping hidden 9 frame]

File ~/jax-metal/lib/python3.9/site-packages/flax/linen/linear.py:256, in Dense.__call__(self, inputs)
    246 @compact
    247 def __call__(self, inputs: Array) -> Array:
    248   """Applies a linear transformation to the inputs along the last dimension.
    249 
    250   Args:
   (...)
    254     The transformed input.
    255   """
--> 256   kernel = self.param(
    257     'kernel',
    258     self.kernel_init,
    259     (jnp.shape(inputs)[-1], self.features),
    260     self.param_dtype,
    261   )
    262   if self.use_bias:
    263     bias = self.param(
    264       'bias', self.bias_init, (self.features,), self.param_dtype
    265     )

    [... skipping hidden 2 frame]

File ~/jax-metal/lib/python3.9/site-packages/jax/_src/nn/initializers.py:335, in variance_scaling.<locals>.init(key, shape, dtype)
    332 if jnp.issubdtype(dtype, jnp.floating):
    333   # constant is stddev of standard normal truncated to (-2, 2)
    334   stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
--> 335   return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev
    336 else:
    337   # constant is stddev of complex standard normal truncated to 2
    338   stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)

File ~/jax-metal/lib/python3.9/site-packages/jax/_src/random.py:852, in truncated_normal(key, lower, upper, shape, dtype)
    850 if shape is not None:
    851   shape = core.as_named_shape(shape)
--> 852 return _truncated_normal(key, lower, upper, shape, dtype)

    [... skipping hidden 14 frame]

File ~/jax-metal/lib/python3.9/site-packages/jax/_src/compiler.py:238, in backend_compile(backend, module, options, host_callbacks)
    233   return backend.compile(built_c, compile_options=options,
    234                          host_callbacks=host_callbacks)
    235 # Some backends don't have `host_callbacks` option yet
    236 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    237 # to take in `host_callbacks`
--> 238 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: /Users/rt/jax-metal/lib/python3.9/site-packages/flax/core/scope.py:990:0: error: failed to legalize operation 'mhlo.erf'
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:1889:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/linear.py:256:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:1233:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:701:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:3103:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/core/scope.py:1101:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/core/scope.py:1137:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:2316:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/linen/module.py:2464:0: note: called from
/Users/rt/jax-metal/lib/python3.9/site-packages/flax/core/scope.py:990:0: note: see current operation: %111 = "mhlo.erf"(%110) : (tensor<f32>) -> tensor<f32>

Here's my environment info:

jax: 0.4.30
jaxlib: 0.4.30
numpy: 1.26.4
python: 3.9.6 (default, May 7 2023, 23:32:44) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]

Flax: 0.8.5. Apple M2 Pro, macOS Sonoma 14.6.1.
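As a possible stopgap for the `mhlo.erf` failure above (not an official fix): the default Dense kernel initializer samples a truncated normal, which lowers to erf. A rational approximation built only from elementary ops (Abramowitz & Stegun 7.1.26, |error| < 1.5e-7) avoids that op entirely. This is a sketch; `erf_approx` is a hypothetical helper, not a JAX or Flax API:

```python
import jax.numpy as jnp

def erf_approx(x):
    # Abramowitz & Stegun 7.1.26 rational approximation of erf, using
    # only multiply, add, and exp, which sidesteps mhlo.erf on backends
    # that cannot legalize it. Absolute error is below 1.5e-7.
    sign = jnp.sign(x)
    x = jnp.abs(x)
    t = 1.0 / (1.0 + 0.3275911 * x)
    poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t
             - 0.284496736) * t + 0.254829592) * t
    return sign * (1.0 - poly * jnp.exp(-x * x))
```

Alternatively, passing a uniform-based initializer (e.g. `nn.initializers.variance_scaling(1.0, "fan_in", "uniform")`) to `nn.Dense` may avoid the erf-based truncated normal altogether; I have not verified this on jax-metal.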

twiecki commented 2 months ago

Upgraded to macOS Sequoia 15 Beta 7 and now I'm getting this error with the numpyro sampler:

File ~/micromamba/envs/pymc5_new/lib/python3.12/site-packages/numpyro/util.py:400, in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    398 with tqdm.trange(upper) as t:
    399     for i in t:
--> 400         vals = _body_fn(i, *vals)
    402         t.set_description(progbar_desc(i), refresh=False)
    403         if diagnostics_fn:

    [... skipping hidden 14 frame]

File ~/micromamba/envs/pymc5_new/lib/python3.12/site-packages/jax/_src/compiler.py:267, in backend_compile(backend, module, options, host_callbacks)
    262   return backend.compile(built_c, compile_options=options,
    263                          host_callbacks=host_callbacks)
    264 # Some backends don't have `host_callbacks` option yet
    265 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    266 # to take in `host_callbacks`
--> 267 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: <ipython-input-1-5ac960273fe0>:9:12: error: failed to legalize operation 'mhlo.popcnt'
<ipython-input-3-84e33ccd7f24>:9:12: note: called from
/Users/twiecki/micromamba/envs/pymc5_new/bin/ipython:10:13: note: called from
    sys.exit(start_ipython())
            ^
<ipython-input-1-5ac960273fe0>:9:12: note: see current operation: %3616 = "mhlo.popcnt"(%3615) : (tensor<si32>) -> tensor<si32>
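For the `mhlo.popcnt` case, a classic SWAR bit-twiddling fallback can compute a 32-bit population count with only shifts, masks, adds, and one multiply. This is a sketch of a workaround, not a JAX API; `popcount32` is a hypothetical helper name:

```python
def popcount32(x):
    # SWAR population count for 32-bit values, sketched as a fallback
    # for backends that cannot legalize mhlo.popcnt (e.g. jax-metal).
    # Works on plain Python ints and, elementwise, on jnp.uint32 arrays.
    x = x - ((x >> 1) & 0x55555555)                 # pairwise bit sums
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333)  # nibble sums
    x = (x + (x >> 4)) & 0x0F0F0F0F                 # byte sums
    return ((x * 0x01010101) & 0xFFFFFFFF) >> 24    # total in top byte
```

Whether patching this in at the numpyro level is practical depends on where `lax.population_count` is invoked, but it shows the op is expressible in primitives the plugin already supports.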

shuhand0 commented 2 months ago

As noted in the original post about CustomCall: jax-metal does not support all custom-call targets, and may not in the near future. The Sharding target is supported (as an identity); the 'Qr' target is not supported. mhlo.popcnt is not supported so far.