jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Metal: add support for complex-valued math #16416

Open hawkinsp opened 1 year ago

hawkinsp commented 1 year ago

Description

Repro:

In [1]: import jax

In [2]: jax.lax.add(1+2j, 3+4j)
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

loc("-":3:10): error: 'mps.add' op operand #0 must be tensor of MPS type values, but got 'tensor<complex<f32>>'
Segmentation fault: 11

It's fine that the plugin doesn't support complex numbers, but it shouldn't segfault.
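Until the plugin handles complex inputs gracefully, one user-side workaround (a sketch of ours, not from this thread, using only standard JAX device APIs) is to pin complex-valued ops to the CPU backend:

```python
import jax

# Sketch of a workaround while jax-metal lacks complex support: run the
# complex-valued op on the CPU backend instead of the Metal device.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    z = jax.lax.add(1 + 2j, 3 + 4j)  # → (4+6j), computed on CPU
```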

What jax/jaxlib version are you using?

jax 0.4.11, jaxlib 0.4.10, jax-metal 0.0.2

Which accelerator(s) are you using?

Apple GPU

Additional system info

No response

NVIDIA GPU info

No response

Saladino93 commented 1 year ago

Following this one. @hawkinsp any news?

hawkinsp commented 1 year ago

@Saladino93 This is an issue in the Metal plugin, so it can only be fixed by Apple. The right people to take a look are already assigned to the bug.

shuhand0 commented 1 year ago

Metal plugin doesn't support complex yet, but we will support erroring out more gracefully.

FilipeMaia commented 1 year ago

Any plans to support complex numbers in the near future?

shuhand0 commented 1 year ago

Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

FilipeMaia commented 1 year ago

I'm doing X-ray scattering simulations, which require wave-interference calculations using complex numbers; the same applies to electromagnetic simulations in general. Another extremely important set of functions requiring complex numbers is FFTs.
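For context, here is a minimal sketch (our own illustration, not from the thread) of the kind of interference calculation this implies, using jax.numpy's existing complex support:

```python
import jax.numpy as jnp

# Two-source interference: superpose complex plane waves and take the
# intensity |psi|^2. Function and parameter names are illustrative.
def interference_intensity(x, k1, k2):
    psi = jnp.exp(1j * k1 * x) + jnp.exp(1j * k2 * x)
    return jnp.abs(psi) ** 2

x = jnp.linspace(0.0, 2.0 * jnp.pi, 8)
intensity = interference_intensity(x, 1.0, 2.0)  # intensity[0] == 4 (fully constructive)
```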

denehoffman commented 1 year ago

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

I'd like to chime in here and say that this would also be a huge plus for my work in particle physics, I do a lot of amplitude analysis and quantum amplitudes are typically complex (in particular, lots of spherical harmonics and Breit-Wigners). Universal support of complex numbers on Metal would help me pitch this library more to my collaboration. Although we typically use large computer clusters which don't run Apple Silicon, there is certainly demand for being able to run some of these things on laptops, and almost everyone I know in this field uses M1/M2 Macs for their personal computers for some reason.
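As a sketch of the amplitudes mentioned (our own illustration; the function name and the rho-like mass/width values are ours, not from the thread):

```python
import jax.numpy as jnp

# Illustrative Breit-Wigner amplitude, a typical complex-valued building
# block in particle-physics amplitude analyses.
def breit_wigner(s, mass, width):
    return 1.0 / (mass**2 - s - 1j * mass * width)

amp = breit_wigner(jnp.array(0.6), mass=0.775, width=0.149)
intensity = jnp.abs(amp) ** 2  # the real, positive observable
```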

Is there somewhere I can follow the progress on this? Is jax-metal being developed in the open? I can't find much on the Apple developer forums.

joshuazlin commented 11 months ago

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

If it's helpful - I ran into this today as well, also as a physicist. Writing some code for lattice-QCD applications, where gauge-fields and all derived quantities are complex-valued (all production code is run on clusters, but local testing on apple silicon with jax-metal would be nice)

shuhand0 commented 10 months ago

For the complex element types, do you require fp64, or would fp32 be sufficient for your jax applications?

FilipeMaia commented 10 months ago

Single precision would be good enough for me.

abocquet commented 10 months ago

> Complex support has been added to our backend stack, and jax-metal will integrate the change. For us to provide a good coverage, could you share with us what kind of applications you'd like to accelerate on jax-metal with complex support?

Hello, I'm a member of the development team for dynamiqs (dynamiqs on GitHub), a Python package designed for quantum mechanics simulations. Our package depends heavily on complex-number computations. Specifically, we mostly perform complex matrix-matrix multiplications and use various linear algebra routines such as eigh, expm and schur. We are in the process of switching from PyTorch to JAX, and support for complex numbers on Apple Silicon in JAX would be a huge plus.

Most simulations can be run in single precision, but double precision would be required for the high-accuracy runs we sometimes need.

EDIT: updated link
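A minimal sketch (ours, not from dynamiqs) of this kind of complex linear algebra, using standard jax.numpy and jax.scipy calls:

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Diagonalize a small Hermitian "Hamiltonian" and build the time-evolution
# unitary U = exp(-i H t). The matrix values and t are illustrative.
H = jnp.array([[0.0, 1.0j], [-1.0j, 0.0]], dtype=jnp.complex64)
evals, evecs = jnp.linalg.eigh(H)  # eigenvalues of a Hermitian matrix are real
U = expm(-1j * H * 0.5)            # complex matrix exponential via jax.scipy
```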

rajasekharporeddy commented 6 months ago

Hi @hawkinsp

Using jax-metal 0.0.7 with jax 0.4.28 and jaxlib 0.4.28, the code above now raises an XlaRuntimeError instead of a segmentation fault:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>):
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
  "func.return"(%0) : (tensor<complex<f32>>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation: 
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default"}, {mhlo.layout_mode = "default"}], function_type = (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>):
  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
  "func.return"(%0) : (tensor<complex<f32>>) -> ()
}) : () -> ()

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

jax.print_environment_info():

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct  2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

Thank you.

hawkinsp commented 6 months ago

Thanks @rajasekharporeddy. I think we can now redesignate this as an "enhancement" rather than a "bug": it's a bug if the plugin crashes or produces wrong output, but this is now just a missing feature.

akelman commented 3 months ago

Chiming in to add that I would also greatly appreciate support for complex numbers. I also do lattice QCD (physics simulations), most of which runs on clusters, but local testing would be a huge plus. In my use case, FP64 support is crucial.

benjaminpope commented 2 months ago

It's essential for physics and ML-in-physics applications to have complex support at both the complex64 and complex128 levels. What is the timescale for this being implemented in Metal?
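For reference, complex128 in JAX also requires the x64 flag regardless of backend (this is standard JAX configuration, independent of the Metal question):

```python
import jax

# By default JAX downcasts 64-bit types, so complex values are complex64.
# Enabling x64 is a prerequisite for complex128 on any backend.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

z = jnp.asarray(1 + 2j)
print(z.dtype)  # complex128 once x64 is enabled
```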

phansel commented 1 month ago

https://github.com/rafael-fuente/diffractsim/ uses np.float (float64) and np.complex (complex128) dtypes, neither of which is supported by the JAX GPU backend via jax-metal.

Even without jax-metal, swapping numpy for jax.numpy still provides a substantial speedup over plain numpy.
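As a sketch of that swap (our own minimal example; the function name and shapes are hypothetical), a jitted complex FFT kernel in jax.numpy mirrors the numpy code one-for-one:

```python
import jax
import jax.numpy as jnp

# Angular-spectrum-style propagation step: FFT, apply a complex phase,
# inverse FFT. jax.numpy is a drop-in rewrite of the numpy version, and
# jax.jit compiles the whole pipeline into one fused kernel.
@jax.jit
def propagate(field, phase):
    return jnp.fft.ifft2(jnp.fft.fft2(field) * jnp.exp(1j * phase))

field = jnp.ones((64, 64), dtype=jnp.complex64)
out = propagate(field, jnp.zeros((64, 64)))  # zero phase: out ≈ field
```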

benjaminpope commented 1 month ago

Yep - anything involving optics or wave physics more or less requires complex dtypes.

smorad commented 3 weeks ago

Just chiming in to note that popular linear complexity transformer alternatives like Mamba and LRU also use complex numbers.