a-gn opened this issue 2 months ago
Here is a minimal working example that, I think, shows the error function (`erf`) has an issue in jax-metal:
```python
#! /usr/bin/env python
from jax.scipy.special import erf

# With jax-metal 0.1.0 this call fails because 'mhlo.erf' cannot be legalized
erf(0)
```
Results in:
```
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: ./test.py:5:0: error: failed to legalize operation 'mhlo.erf'
./test.py:5:0: note: see current operation: %0 = "mhlo.erf"(%arg0) : (tensor<f32>) -> tensor<f32>
```
Using:
- python 3.12.6
- jax 0.4.31
- jaxlib 0.4.31
- jax_metal 0.1.0
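The error means the Metal backend has no lowering for `mhlo.erf`. One possible stopgap on 0.1.0 (an assumption, not something proposed in this thread) is to avoid that op entirely by building `erf` from elementary operations. The sketch below uses the Abramowitz & Stegun 7.1.26 rational approximation (maximum absolute error about 1.5e-7); `erf_approx` is a hypothetical helper name, and it is written with Python's `math` module so it runs anywhere.

```python
import math

# Abramowitz & Stegun 7.1.26 coefficients (max abs error about 1.5e-7)
_P = 0.3275911
_A = (0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429)


def erf_approx(x: float) -> float:
    """Approximate erf(x) using only abs, multiply, add, divide and exp."""
    sign = 1.0 if x >= 0 else -1.0  # erf is odd: erf(-x) = -erf(x)
    ax = abs(x)
    t = 1.0 / (1.0 + _P * ax)
    # Horner evaluation of a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5
    poly = 0.0
    for a in reversed(_A):
        poly = (poly + a) * t
    return sign * (1.0 - poly * math.exp(-ax * ax))
```

Porting it to arrays would mean swapping in `jnp.abs`, `jnp.exp` and `jnp.where` for the sign handling; whether every one of those ops lowers on jax-metal 0.1.0 is untested here.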
As I mentioned in a discussion on this repo, I'd like to try contributing implementations of these operations, but I don't know how to contribute to jax-metal. Is the source not public?
The source is not open. `erf` has been fixed in the 0.1.1 patch release.
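Since the fix ships in 0.1.1, a quick way to confirm which jax-metal release is installed is to ask the package metadata. This is a minimal sketch assuming the distribution is named `jax-metal` (as in `pip install jax-metal`); `parse_version` and `has_erf_fix` are hypothetical helpers, and the parser only compares the leading numeric components.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    """Keep the leading numeric components, e.g. "0.1.1" -> (0, 1, 1)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component
        parts.append(int(digits))
    return tuple(parts)


def has_erf_fix(min_version: str = "0.1.1") -> Optional[bool]:
    """True/False if jax-metal is installed, None if it is not."""
    try:
        installed = version("jax-metal")
    except PackageNotFoundError:
        return None
    return parse_version(installed) >= parse_version(min_version)
```

If it returns `False`, upgrading with `pip install -U jax-metal` should pick up the patched release.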
Description
This MLP fails to run with jax-metal:
with jax-metal installed:
without jax-metal:
System info (python version, jaxlib version, accelerator, etc.)