jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax-metal slowdown on m3 max when external monitor is plugged in #24163

Open bjeurissen opened 1 month ago

bjeurissen commented 1 month ago

Description

I am using macOS Sequoia 15.0.1 (24A348) on an Apple M3 Max.

python: 3.12.7, jax: 0.4.34, jaxlib: 0.4.34, jax-metal: 0.1.0

I consistently see a slowdown of more than a factor of 2 on a simple toy problem whenever I attach an external monitor.

E.g. without an external monitor I get:

10000/10000 [00:09<00:00, 1021.68it/s]

With an external monitor, I get:

10000/10000 [00:24<00:00, 407.24it/s]

My suspicion was that the refresh rate of the monitor could have something to do with it. I therefore unplugged the monitor and tested with the internal display only, this time reducing its refresh rate from 120 Hz (the default) to 60 Hz in the macOS Display settings. Again I got a reduction in speed, although a slightly smaller one than with the external monitor plugged in:

10000/10000 [00:18<00:00, 549.69it/s]
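For reference, the slowdown factors implied by the three throughput figures above can be worked out directly. This is only arithmetic on the numbers quoted in this report; the dictionary keys are descriptive labels chosen for illustration.

```python
# Throughput figures (iterations per second) quoted in this report.
rates = {
    "internal display, 120 Hz, no external monitor": 1021.68,
    "external monitor attached": 407.24,
    "internal display, 60 Hz, no external monitor": 549.69,
}

# Use the fastest configuration as the baseline.
baseline = rates["internal display, 120 Hz, no external monitor"]

# Slowdown factor = baseline throughput / observed throughput.
slowdowns = {name: baseline / rate for name, rate in rates.items()}

for name, factor in slowdowns.items():
    print(f"{name}: {factor:.2f}x slower than baseline")
```

This gives roughly 2.51x for the external monitor case and 1.86x for the 60 Hz internal display case.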

The strange thing is that when I fiddle a lot with the display controls (switching between mirroring and extended screen), I can sometimes get it to run at 1000it/s even with the external display plugged in. To me this suggests that it is probably a macOS video/MPS driver issue and not specific to jax-metal. However, I could not reproduce the problem with any other numerical libraries that support MPS, although none of those managed more than 400it/s to begin with, even without an external monitor attached.
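The toy problem itself is not included in this report, so the following is only a minimal sketch of how such a throughput comparison could be timed under the different display configurations. The names `measure_throughput` and `toy_step`, the 256x256 workload, and the iteration counts are all assumptions made for illustration, not the code that produced the numbers above. The one JAX-specific detail is that dispatch is asynchronous, so the timer should stop only after `jax.block_until_ready` returns.

```python
import time


def measure_throughput(step, state, n_iters=1000, sync=lambda x: x):
    """Return iterations per second for repeatedly applying `step` to `state`.

    `sync` should block until the result is actually computed; for JAX,
    pass `jax.block_until_ready`. The default is a no-op.
    """
    # Warm-up call so one-time costs (e.g. JIT compilation) are excluded.
    state = sync(step(state))
    start = time.perf_counter()
    for _ in range(n_iters):
        state = step(state)
    # Wait for any asynchronously dispatched work before stopping the clock.
    sync(state)
    elapsed = time.perf_counter() - start
    return n_iters / elapsed


if __name__ == "__main__":
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("JAX not installed; skipping the device benchmark.")
    else:
        @jax.jit
        def toy_step(x):
            # Hypothetical workload, NOT the toy problem from this report.
            return jnp.tanh(x @ x)

        x0 = jnp.eye(256, dtype=jnp.float32)
        rate = measure_throughput(
            toy_step, x0, n_iters=10_000, sync=jax.block_until_ready
        )
        print(f"{rate:.2f} it/s on {jax.devices()[0]}")
```

Running the same script once per display configuration (external monitor attached, internal display at 120 Hz, internal display at 60 Hz) would give directly comparable it/s figures without progress-bar overhead.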

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1728313982.685485  115532 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 128.00 GB
maxCacheSize: 48.00 GB

I0000 00:00:1728313982.692505  115532 service.cc:145] XLA service 0x600003911200 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1728313982.692514  115532 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1728313982.693547  115532 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1728313982.693555  115532 mps_client.cc:384] XLA backend will use up to 103078739968 bytes on device 0 for SimpleAllocator.
jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 15:57:01) [Clang 17.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='<REMOVED>', release='24.0.0', version='Darwin Kernel Version 24.0.0: Tue Sep 24 23:35:10 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6031', machine='arm64')

hawkinsp commented 1 month ago

I also suspect this is a macOS issue and nothing we can address from JAX, but @shuhand0 can determine that.