Open hawkinsp opened 1 year ago
Following this one. @hawkinsp any news?
@Saladino93 This is an issue in the Metal plugin, so it can only be fixed by Apple. The right people to take a look are already assigned to the bug.
The Metal plugin doesn't support complex numbers yet, but we will make it error out more gracefully.
Any plans to support complex numbers in the near future?
Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?
I'm doing X-ray scattering simulations, which require wave interference calculations using complex numbers. This should be common to all electromagnetic simulations. Another extremely important class of functions requiring complex numbers is FFTs.
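To make that use case concrete, here's a toy sketch of the kind of computation involved: superposing two complex plane waves and taking an FFT with jax.numpy. All geometry and values here are illustrative, not taken from the simulations described above.

```python
import jax.numpy as jnp

# Toy two-slit interference: superpose two complex plane waves and
# look at the resulting intensity (geometry and values are made up).
k = 2 * jnp.pi / 0.15                # wavenumber for an arbitrary wavelength
x = jnp.linspace(-1.0, 1.0, 1024)    # detector coordinates
r1 = jnp.sqrt(x**2 + 1.0)            # path length from slit 1
r2 = jnp.sqrt((x - 0.1) ** 2 + 1.0)  # path length from slit 2

field = jnp.exp(1j * k * r1) + jnp.exp(1j * k * r2)  # complex wave sum
intensity = jnp.abs(field) ** 2      # observable fringe pattern
spectrum = jnp.fft.fft(field)        # FFTs also require complex support
```

Both the `jnp.exp(1j * ...)` superposition and `jnp.fft.fft` produce complex-typed arrays, which is exactly what the Metal backend currently rejects.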
I'd like to chime in here and say that this would also be a huge plus for my work in particle physics, I do a lot of amplitude analysis and quantum amplitudes are typically complex (in particular, lots of spherical harmonics and Breit-Wigners). Universal support of complex numbers on Metal would help me pitch this library more to my collaboration. Although we typically use large computer clusters which don't run Apple Silicon, there is certainly demand for being able to run some of these things on laptops, and almost everyone I know in this field uses M1/M2 Macs for their personal computers for some reason.
Is there somewhere I can follow the progress on this? Is jax-metal being developed in an open-source format? I can't find much through the Apple developer forums.
If it's helpful - I ran into this today as well, also as a physicist. I'm writing some code for lattice-QCD applications, where gauge fields and all derived quantities are complex-valued (all production code is run on clusters, but local testing on Apple Silicon with jax-metal would be nice).
For the complex element types, do you require fp64, or would fp32 be good for your jax applications?
Single precision would be good enough for me.
Hello, I'm a member of the development team for dynamiqs (GitHub - dynamiqs), a Python package designed for quantum mechanics simulations. Our package heavily depends on complex number computations. Specifically, we mostly perform complex matrix-matrix multiplications and utilize various linear algebra routines like eigh, expm and schur. We are in the process of switching from PyTorch to JAX, and support of complex numbers on Apple Silicon for JAX would be a huge plus.
Most simulations can be run in single precision, but double precision would be required for the high-accuracy simulations we sometimes need.
EDIT: updated link
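For concreteness, a minimal sketch of the kind of complex linear algebra described above: evolving a state under a Hamiltonian via a complex matrix exponential, plus a Hermitian eigendecomposition. The Hamiltonian here is an illustrative Pauli-X matrix, not code from dynamiqs.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Illustrative Hamiltonian (Pauli-X); real simulations use much larger
# complex matrices built from the physical system.
H = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)

# Unitary time evolution U = exp(-i H t) -- a complex matrix exponential.
U = expm(-1j * H * jnp.pi / 2)

# Hermitian eigendecomposition, another routine mentioned above.
evals, evecs = jnp.linalg.eigh(H)
```

Every intermediate here is complex-typed, so none of it currently compiles on the Metal backend.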
Hi @hawkinsp
Using jax-metal 0.0.7 with jax 0.4.28 and jaxlib 0.4.28, the mentioned code now raises an XlaRuntimeError instead of segmentation fault: 11.
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>):
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
"func.return"(%0) : (tensor<complex<f32>>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>):
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
"func.return"(%0) : (tensor<complex<f32>>) -> ()
}) : () -> ()
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Output of jax.print_environment_info():
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct 2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
Thank you.
Thanks @rajasekharporeddy. I think we can now redesignate this as an "enhancement", not a "bug". It's a bug if the plugin crashes or produces wrong output. It's now just a missing feature.
Chiming in to add that I would also greatly appreciate support for complex numbers. I also do lattice QCD (physics simulations), most of which runs on clusters, but local testing would be a huge plus. In my use case, FP64 support is crucial.
It's essential for physics & ML-in-physics applications to have complex support at both the complex64 and complex128 levels. What is the timescale on this being implemented in Metal?
https://github.com/rafael-fuente/diffractsim/ uses np.float (float64) and np.complex (complex128) dtypes, neither of which are supported for the JAX GPU backend via jax-metal.
Even without jax-metal, swapping numpy for jax.numpy still provides a substantial speedup.
Yep - anything involving optics or wave physics more or less requires complex dtypes.
Just chiming in to note that popular linear complexity transformer alternatives like Mamba and LRU also use complex numbers.
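To illustrate that use case, here is a toy sketch of the complex diagonal linear recurrence at the heart of LRU-style models: a state updated as x_t = lambda * x_(t-1) + u_t with complex eigenvalues lambda, scanned over a sequence. Shapes and values are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Complex decay rates on (just inside) the unit circle, as in LRU-style
# state-space layers (values are illustrative).
lam = jnp.exp(1j * jnp.linspace(0.1, 1.0, 8)) * 0.99

# A dummy input sequence: 16 time steps, state dimension 8.
u = jnp.ones((16, 8), dtype=jnp.complex64)

def step(x, u_t):
    # Diagonal linear recurrence x_t = lam * x_{t-1} + u_t.
    x = lam * x + u_t
    return x, x

_, xs = jax.lax.scan(step, jnp.zeros(8, jnp.complex64), u)
```

The recurrence state is inherently complex, so these models cannot run on a backend without complex dtypes.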
Description
Repro:
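The original repro snippet was not preserved in this extract; a minimal sketch consistent with the traceback later in the thread (a jitted function adding two complex64 values) would be:

```python
import jax
import jax.numpy as jnp

# Hypothetical minimal repro: any complex-typed computation, e.g. adding
# two complex64 scalars under jit. On jax-metal 0.0.2 this segfaulted;
# on CPU it runs fine.
@jax.jit
def add(a, b):
    return a + b

x = jnp.complex64(1 + 2j)
y = jnp.complex64(3 - 1j)
out = add(x, y)
```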
It's fine that the plugin doesn't support complex numbers, but it shouldn't segfault.
What jax/jaxlib version are you using?
jax 0.4.11, jaxlib 0.4.10, jax-metal 0.0.2
Which accelerator(s) are you using?
Apple GPU
Additional system info
No response
NVIDIA GPU info
No response