jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Apache License 2.0

Metal: add support for complex-valued math #16416

Open hawkinsp opened 1 year ago

hawkinsp commented 1 year ago



In [1]: import jax

In [2]: jax.lax.add(1+2j, 3+4j)
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

loc("-":3:10): error: 'mps.add' op operand #0 must be tensor of MPS type values, but got 'tensor<complex<f32>>'
Segmentation fault: 11

It's fine that the plugin doesn't support complex numbers, but it shouldn't segfault.
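Until the plugin handles complex dtypes, one possible workaround is to keep complex-valued work on the host CPU. This is a minimal sketch. It assumes the CPU backend is still registered alongside jax-metal (`jax.devices("cpu")`), which has not been checked here for every plugin version:

```python
import jax
import jax.numpy as jnp

# Pick the host CPU device (assumption: available next to the Metal plugin).
cpu = jax.devices("cpu")[0]

# Inside this context, computations on uncommitted inputs are placed on the
# CPU instead of the default (e.g. Metal) device.
with jax.default_device(cpu):
    z = jax.lax.add(1 + 2j, 3 + 4j)

print(z)
```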

What jax/jaxlib version are you using?

jax 0.4.11, jaxlib 0.4.10, jax-metal 0.0.2

Which accelerator(s) are you using?

Apple GPU

Additional system info

No response



Saladino93 commented 1 year ago

Following this one. @hawkinsp any news?

hawkinsp commented 1 year ago

@Saladino93 This is an issue in the Metal plugin, so it can only be fixed by Apple. The right people to take a look are already assigned to the bug.

shuhand0 commented 1 year ago

The Metal plugin doesn't support complex numbers yet, but we will make it error out more gracefully.
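Until the plugin reports this itself, user code can fail early with a readable Python exception instead of handing complex inputs to the backend. The helper `check_complex_supported` below is hypothetical (not part of JAX or jax-metal). It also assumes that the platform name returned by `jax.default_backend()` for the plugin compares equal to "metal" case-insensitively:

```python
import jax
import jax.numpy as jnp


def check_complex_supported(*arrays, backend=None):
    """Raise a clear TypeError if complex inputs would target the Metal backend.

    Hypothetical user-side guard; assumes the Metal platform is named "metal"
    (compared case-insensitively).
    """
    backend = (backend or jax.default_backend()).lower()
    if backend != "metal":
        return
    for a in arrays:
        if jnp.iscomplexobj(a):
            raise TypeError(
                f"complex dtype {jnp.result_type(a)} is not supported on the Metal backend"
            )
```

Calling it at the top of a function that is about to dispatch work to the device turns a crash into an ordinary `TypeError`.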

FilipeMaia commented 11 months ago

Any plans to support complex numbers in the near future?

shuhand0 commented 11 months ago

Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

FilipeMaia commented 11 months ago

I'm doing X-ray scattering simulations, which require wave-interference calculations with complex numbers. This should be common to all electromagnetic simulations. Another extremely important operation that requires complex numbers is the FFT.
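A toy version of this kind of workload, as a sketch: two coherent point sources superposed as complex phasors exp(i k r), the resulting interference intensity |E|^2, and an FFT of the complex field. The geometry (source positions, screen distance, wavenumber) is made up for illustration:

```python
import jax.numpy as jnp

# Illustrative setup: two coherent point sources observed on a screen.
k = 2 * jnp.pi                       # wavenumber for a wavelength of 1 (arbitrary units)
sources = jnp.array([-0.5, 0.5])     # source positions along x
screen_distance = 10.0               # distance from the sources to the screen
x = jnp.linspace(-5.0, 5.0, 1025)    # sample points on the screen

# Path length from every source to every screen point, shape (1025, 2).
r = jnp.sqrt((x[:, None] - sources[None, :]) ** 2 + screen_distance ** 2)

# Superpose complex phasors exp(i k r); the interference lives in the phases.
field = jnp.sum(jnp.exp(1j * k * r), axis=1)   # complex64 with default settings
intensity = jnp.abs(field) ** 2

# FFT of the complex field.
spectrum = jnp.fft.fft(field)
```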

denehoffman commented 11 months ago

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

I'd like to chime in here and say that this would also be a huge plus for my work in particle physics. I do a lot of amplitude analysis, and quantum amplitudes are typically complex (in particular, lots of spherical harmonics and Breit-Wigners). Universal support for complex numbers on Metal would help me pitch this library more to my collaboration. Although we typically use large computing clusters that don't run Apple Silicon, there is certainly demand for running some of these things on laptops, and almost everyone I know in this field uses an M1/M2 Mac as their personal computer, for some reason.

Is there somewhere I can follow the progress on this? Is jax-metal being developed in the open? I can't find much through the Apple developer forums.
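For reference, a Breit-Wigner amplitude is a complex function of s. Here is a sketch of the constant-width relativistic form A(s) = 1 / (m^2 - s - i m Gamma). The function name `breit_wigner` and the mass and width values (GeV) are illustrative, not fit results:

```python
import jax
import jax.numpy as jnp


def breit_wigner(s, mass, width):
    """Constant-width relativistic Breit-Wigner amplitude 1 / (m^2 - s - i m Gamma)."""
    return 1.0 / (mass**2 - s - 1j * mass * width)


# Illustrative resonance parameters.
m, gamma = 0.775, 0.149
s = jnp.linspace(0.3, 1.0, 200) ** 2        # scan sqrt(s) from 0.3 to 1.0
amp = breit_wigner(s, m, gamma)             # complex amplitudes
intensity = jnp.abs(amp) ** 2               # peaks at s = m^2

# Complex intermediates are differentiable; grad needs a real-valued output.
dI_dm = jax.grad(lambda mm: jnp.abs(breit_wigner(0.6, mm, gamma)) ** 2)(m)
```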

joshuazlin commented 8 months ago

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

If it's helpful: I ran into this today as well, also as a physicist. I'm writing code for lattice-QCD applications, where gauge fields and all derived quantities are complex-valued. All production code runs on clusters, but local testing on Apple silicon with jax-metal would be nice.
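Lattice-QCD link variables are SU(3) matrices. As a much smaller stand-in, the sketch below uses the compact U(1) gauge group on a 2D periodic lattice. Each link there is a single complex phase exp(i theta), and the sketch computes the average plaquette. The function name and the `(direction, x, y)` array layout are assumptions made for illustration. An SU(3) version would replace the scalar phases with 3x3 complex matrices and the products with matrix products:

```python
import jax
import jax.numpy as jnp


def u1_plaquette(theta):
    """Average plaquette Re(U_P) of a 2D U(1) lattice gauge field.

    theta has shape (2, L, L): theta[mu, x, y] is the angle on the link
    leaving site (x, y) in direction mu (0 = x, 1 = y); boundaries are periodic.
    """
    U = jnp.exp(1j * theta)                  # complex link variables
    U0, U1 = U[0], U[1]
    # U_P(x) = U_0(x) U_1(x + e0) conj(U_0(x + e1)) conj(U_1(x))
    U1_shift0 = jnp.roll(U1, -1, axis=0)     # U_1(x + e0)
    U0_shift1 = jnp.roll(U0, -1, axis=1)     # U_0(x + e1)
    plaq = U0 * U1_shift0 * jnp.conj(U0_shift1) * jnp.conj(U1)
    return jnp.mean(plaq.real)
```

The plaquette is gauge invariant, which makes a convenient check: shifting each link angle by alpha(x) - alpha(x + mu) should leave the result unchanged.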

shuhand0 commented 8 months ago

For complex element types, do you require fp64, or would fp32 be good enough for your JAX applications?

FilipeMaia commented 8 months ago

Single precision would be good enough for me.
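On the JAX side, the fp32/fp64 choice maps to `complex64` vs `complex128`. JAX disables 64-bit types by default, so `complex128` has to be enabled explicitly with the `jax_enable_x64` flag. The sketch below runs on the default backend. This thread does not establish whether jax-metal accepts `complex128`:

```python
import jax
import jax.numpy as jnp

# 64-bit types are off by default; enable them before creating any arrays.
jax.config.update("jax_enable_x64", True)

phases = jnp.linspace(0.0, 1.0, 4)
z32 = jnp.exp(1j * phases.astype(jnp.float32))   # complex64 (single precision)
z64 = jnp.exp(1j * phases.astype(jnp.float64))   # complex128 (double precision)

print(z32.dtype, z64.dtype)  # complex64 complex128
```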

abocquet commented 8 months ago

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

Hello, I'm a member of the development team for dynamiqs (on GitHub), a Python package designed for quantum mechanics simulations. Our package relies heavily on complex-number computations. Specifically, we mostly perform complex matrix-matrix multiplications and use various linear algebra routines such as eigh, expm, and schur. We are in the process of switching from PyTorch to JAX, and support for complex numbers on Apple Silicon in JAX would be a huge plus.

Most simulations can run in single precision, but double precision is required for the high-accuracy simulations we sometimes need to run.

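The sketch below shows the operations listed above on an 8x8 random Hermitian "Hamiltonian" (hypothetical data). It computes the propagator exp(-iHt) via `jax.scipy.linalg.expm`, computes the same propagator via `jnp.linalg.eigh`, and evolves a density matrix using complex matrix-matrix products. `schur` is left out to keep it short:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def random_hermitian(key, n):
    """Random n x n complex Hermitian matrix (illustrative Hamiltonian)."""
    kr, ki = jax.random.split(key)
    a = jax.random.normal(kr, (n, n)) + 1j * jax.random.normal(ki, (n, n))
    return (a + a.conj().T) / 2


H = random_hermitian(jax.random.PRNGKey(0), 8)
t = 0.5

# Unitary propagator U = exp(-i H t) via the matrix exponential.
U = expm(-1j * H * t)

# Same propagator from the eigendecomposition H = V diag(w) V^dagger.
w, V = jnp.linalg.eigh(H)
U_eig = (V * jnp.exp(-1j * w * t)) @ V.conj().T

# Evolve a pure-state density matrix: rho -> U rho U^dagger.
psi = jnp.zeros(8, dtype=jnp.complex64).at[0].set(1.0)
rho = jnp.outer(psi, psi.conj())
rho_t = U @ rho @ U.conj().T
```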

rajasekharporeddy commented 3 months ago

Hi @hawkinsp

Using jax-metal 0.0.7 with jax 0.4.28 and jaxlib 0.4.28, the code from the original report now raises an XlaRuntimeError instead of crashing with "Segmentation fault: 11":

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>):
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
  "func.return"(%0) : (tensor<complex<f32>>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>):
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
  "func.return"(%0) : (tensor<complex<f32>>) -> ()
}) : () -> ()

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
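The dump shows that the program JAX hands over is a plain elementwise add on `tensor<complex<f32>>`. The plugin rejects it while legalizing the function's input/output types. One way to inspect what JAX emits for the repro without executing it is to lower the jitted function and print its text. In recent JAX versions the op likely appears as `stablehlo.add` rather than the `mhlo.add` shown in the plugin's dump:

```python
import jax

# Trace and lower the original repro to the IR that JAX hands to the
# backend, without compiling or executing it.
lowered = jax.jit(jax.lax.add).lower(1 + 2j, 3 + 4j)
print(lowered.as_text())
```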


jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct  2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
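The environment summary above appears to be the output of `jax.print_environment_info()`. That function is a convenient way to collect version and device information when reporting a plugin issue. A minimal sketch:

```python
import jax

# Collect the jax/jaxlib/numpy/python versions and visible devices as a string.
info = jax.print_environment_info(return_string=True)
print(info)
```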

Thank you.

hawkinsp commented 3 months ago

Thanks @rajasekharporeddy. I think we can now redesignate this as an "enhancement", not a "bug". It's a bug if the plugin crashes or produces wrong output. It's now just a missing feature.

akelman commented 3 weeks ago

Chiming in to add that I would also greatly appreciate support for complex numbers. I also do lattice QCD (physics simulations), most of which runs on clusters, but local testing would be a huge plus. In my use case, FP64 support is crucial.