jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

error: failed to legalize operation 'mhlo.erf' with jax-metal #23384

Open a-gn opened 2 months ago

a-gn commented 2 months ago

Description

This MLP fails to run with jax-metal:

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.typing as jt

class MLP(nn.Module):
    mid_features: tuple[int, ...]
    out_features: int

    @nn.compact
    def __call__(
        self,
        x: jt.ArrayLike,
    ):
        for out_feature_count in self.mid_features:
            x = nn.Dense(out_feature_count)(x)
            x = nn.relu(x)
        return nn.Dense(self.out_features)(x)

mlp = MLP((64, 64, 64), 6)
prng_key = jax.random.key(7)
params = mlp.init(prng_key, jnp.ones((2, 32)))
data = jax.random.uniform(prng_key, (4, 32), float, -10000, 10000)
print(mlp.apply(params, data))

With jax-metal installed:

arno@mba-2 ~/p/reimpl (main) [1]> python test/test_mlp.py 
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1725302902.052601 5204059 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1725302902.081618 5204059 service.cc:145] XLA service 0x11fd1e040 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725302902.081748 5204059 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1725302902.083499 5204059 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1725302902.083613 5204059 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
Traceback (most recent call last):
  File "/Users/arno/projects/reimpl/test/test_mlp.py", line 24, in <module>
    params = mlp.init(prng_key, jnp.ones((2, 32)))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/projects/reimpl/test/test_mlp.py", line 17, in __call__
    x = nn.Dense(out_feature_count)(x)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/linear.py", line 256, in __call__
    kernel = self.param(
             ^^^^^^^^^^^
  File "/Users/arno/venv/rs311/lib/python3.11/site-packages/jax/_src/nn/initializers.py", line 335, in init
    return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/arno/venv/rs311/lib/python3.11/site-packages/jax/_src/random.py", line 831, in truncated_normal
    return _truncated_normal(key, lower, upper, shape, dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:990:14: error: failed to legalize operation 'mhlo.erf'
      value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
             ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1889:8: note: called from
    v = self.scope.param(name, init_fn, *init_args, unbox=unbox, **init_kwargs)
       ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/linear.py:256:13: note: called from
    kernel = self.param(
            ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1233:14: note: called from
          y = run_fun(self, *args, **kwargs)
             ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:701:13: note: called from
      return self._call_wrapped_method(fun, args, kwargs)
            ^
/Users/arno/projects/reimpl/test/test_mlp.py:17:16: note: called from
            x = nn.Dense(out_feature_count)(x)
               ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:1233:14: note: called from
          y = run_fun(self, *args, **kwargs)
             ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:701:13: note: called from
      return self._call_wrapped_method(fun, args, kwargs)
            ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/linen/module.py:3103:13: note: called from
      return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
            ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:1101:10: note: called from
      y = fn(root, *args, **kwargs)
         ^
/Users/arno/venv/rs311/lib/python3.11/site-packages/flax/core/scope.py:990:14: note: see current operation: %109 = "mhlo.erf"(%108) : (tensor<f32>) -> tensor<f32>
      value = init_fn(self.make_rng('params'), *init_args, **init_kwargs)
             ^

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Without jax-metal, the same script runs and prints:

arno@mba-2 ~/p/reimpl (main)> python test/test_mlp.py
[[  636.52386   882.25415  2988.575     192.267     866.60596 -1293.6633 ]
 [ 4585.539    2880.368    2868.1316     78.3667   1750.4458   -857.1804 ]
 [  988.6698   2208.9604   2891.1992   -431.1714    776.54626  -211.66962]
 [  979.84326  4824.3716   6499.3325    321.2257   1804.8367    336.66034]]
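The traceback fails inside `random.truncated_normal`, which the `nn.Dense` kernel initializer calls here. JAX's `truncated_normal` applies `lax.erf` to its scalar bounds, which would match the `(tensor<f32>) -> tensor<f32>` erf in the error message. Below is a minimal sketch, assuming only `jax` is installed, that checks this without compiling anything for the backend. `init_kernel` is a hypothetical helper that mirrors the initializer call in the traceback. The shape `(32, 64)` is an illustrative choice matching the first `Dense(64)` layer applied to 32 input features.

```python
import re

import jax
import jax.numpy as jnp


def init_kernel(key):
    # Hypothetical helper: same call the variance-scaling initializer makes
    # in the traceback, random.truncated_normal(key, -2, 2, shape, dtype).
    return jax.random.truncated_normal(key, -2.0, 2.0, (32, 64), jnp.float32)


# make_jaxpr only traces the function; no backend compilation happens here.
jaxpr_text = str(jax.make_jaxpr(init_kernel)(jax.random.key(7)))

# Whole-word match so that the separate "erf_inv" primitive is not counted.
uses_erf = re.search(r"\berf\b", jaxpr_text) is not None
print("truncated_normal uses erf:", uses_erf)
```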

System info (python version, jaxlib version, accelerator, etc.)

Python 3.11.9 (v3.11.9:de54cf5be3, Apr  2 2024, 07:12:50) [Clang 13.0.0 (clang-1300.0.29.30)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1725302834.712529 5202857 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1725302834.728122 5202857 service.cc:145] XLA service 0x116f5d340 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1725302834.728249 5202857 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1725302834.729747 5202857 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1725302834.729766 5202857 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.0.1
python: 3.11.9 (v3.11.9:de54cf5be3, Apr  2 2024, 07:12:50) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='mba-2.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:21 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T8103', machine='arm64')
TheSkyentist commented 1 month ago

Here is a minimal example that, I think, shows the problem lies in jax-metal's handling of the error function (erf):

#! /usr/bin/env python

from jax.scipy.special import erf

erf(0)

Results in:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: ./test.py:5:0: error: failed to legalize operation 'mhlo.erf' ./test.py:5:0: note: see current operation: %0 = "mhlo.erf"(%arg0) : (tensor<f32>) -> tensor<f32>

Using: Python 3.12.6, jax 0.4.31, jaxlib 0.4.31, jax_metal 0.1.0
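Until a jax-metal release with `erf` support is installed, one possible workaround (an untested assumption, not a fix suggested in this thread) is to pin erf-dependent computations to the CPU backend that ships with jaxlib. A minimal sketch using `jax.default_device`, assuming `jax.devices("cpu")` is still available when the Metal plugin is installed:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

# Assumption: the CPU backend bundled with jaxlib remains available even
# when jax-metal makes METAL the default backend.
cpu = jax.devices("cpu")[0]

# Inside this context, arrays created without an explicit device, and the
# computations on them, are placed on the CPU instead of the default backend.
with jax.default_device(cpu):
    y = erf(jnp.float32(0.0))

print(y, y.devices())
```

The same pattern could wrap the MLP's `params = mlp.init(...)` call, with the resulting parameters moved to the Metal device via `jax.device_put` before `mlp.apply`.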

a-gn commented 1 month ago

As I said in a discussion on this repo, I'd like to try contributing these operations, but I don't know how to contribute to jax-metal. The source doesn't seem to be public, is that right?

shuhand0 commented 1 month ago

The source is not open. The 'erf' issue has been fixed in the 0.1.1 patch.
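After upgrading to jax-metal 0.1.1 or later, a quick check along these lines could confirm the fix. It evaluates `erf` on the default backend and compares the result with Python's `math.erf`. The sample points and the tolerance are arbitrary illustrative choices. With jax-metal 0.1.0, the `erf(x)` call would be expected to fail with the same `failed to legalize operation 'mhlo.erf'` error instead.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import erf

points = [-2.0, -0.5, 0.0, 0.5, 2.0]
x = jnp.asarray(points, dtype=jnp.float32)

# Runs on the default backend (METAL when jax-metal is installed).
y = erf(x)

# Largest absolute deviation from Python's double-precision math.erf.
max_err = max(abs(float(a) - math.erf(p)) for a, p in zip(y, points))
print(jax.default_backend(), max_err)
```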