youurayy opened 8 months ago
I can't even try jax-metal version 0.0.4 as recommended here: https://developer.apple.com/metal/jax/ because pip says:
jax-metal 0.0.4 depends on jaxlib==0.4.11
and
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.11 (from versions: 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)
It is also not possible to fall back to the oldest available jaxlib, 0.4.17:
The conflict is caused by:
The user requested jaxlib==0.4.17
jax-metal 0.0.4 depends on jaxlib==0.4.11
So we're in a bit of a pickle here: any longer/heavier training run on MPS slows down dramatically as it swaps the gigabytes of allocated (leaked) memory.
jax-metal==0.0.5 with its pip dependency jaxlib==0.4.20 (which pulls in jax 0.4.20), on MLIR 1.0.0 (Sonoma 14.2.1), also exhibits the leak.
Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
796,737,536
841,973,760
879,001,600
913,145,856
946,585,600
978,878,464
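Successive samples above grow by roughly 32-45 MB each (e.g. 841,973,760 - 796,737,536 ≈ 45 MB), so the process is leaking on the order of tens of megabytes per reported step.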
Hi guys, is there any outlook on this? It's a showstopper for me, and if nobody is free to have a look, maybe I could.
What is the chance this is a bug in jax-metal? (I don't think Apple's jax-metal sources are public.)
Last question - should this be raised in https://developer.apple.com/forums/tags/tensorflow-metal? (This page points to it.)
This is reproducible and we will take a look.
For anyone else having the same issue and being blocked on MPS because of it: have a look at the MLX framework, which is similar to JAX and tailored specifically to Apple Silicon.
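For a rough sense of the API similarity, here is a hypothetical minimal MLX sketch (not a port of the program from this issue; MLX evaluates lazily, so mx.eval forces the computation):

```python
# Hypothetical minimal MLX example, for API comparison only.
# MLX arrays are evaluated lazily; mx.eval() forces the computation.
import mlx.core as mx

def step(x):
    return mx.tanh(mx.matmul(x, mx.transpose(x)))

x = mx.ones((1024, 1024), dtype=mx.float32)
for _ in range(100):
    mx.eval(step(x))
```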
@youurayy It is similar, but it is also missing a lot. Have you found a good solution for MLX profiling?
I am also experiencing the same issue. My program starts with 12 GB of memory usage, but after a few epochs, it increases to around 90 GB. At this point, GPU usage drops to 0%, and my code stops training altogether. Any updates or potential workarounds for this problem would be greatly appreciated.
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
Metal device set to: Apple M3 Pro
systemMemory: 18.00 GB
maxCacheSize: 6.00 GB
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.0
python: 3.11.7 (v3.11.7:fa7a6f2303, Dec 4 2023, 15:22:56) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Atas-MacBook-Pro.fritz.box', release='24.0.0', version='Darwin Kernel Version 24.0.0: Tue Sep 24 23:37:25 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6030', machine='arm64')
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1728661436.923585 7587819 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
I0000 00:00:1728661436.937859 7587819 service.cc:145] XLA service 0x1143d2270 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1728661436.937877 7587819 service.cc:153] StreamExecutor device (0): Metal,
Description
Simply calling a @jit function with any kind of array input leaks a non-trivial amount of memory. The CPU backend is fine, but the device (MPS) exhibits the leak.
Sample program:
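The original snippet is not reproduced here; a minimal sketch consistent with the description above (repeatedly calling a jitted function on the default device and printing the process RSS, with psutil assumed for the measurement) might look like this:

```python
# Minimal sketch (not the original sample program): call a jitted
# function in a loop and print the process RSS; on the METAL backend
# the RSS keeps growing, while on CPU it stays flat.
import os
import psutil  # assumed here for memory measurement

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T)

proc = psutil.Process(os.getpid())
x = jnp.ones((1024, 1024), dtype=jnp.float32)

for i in range(100):
    step(x).block_until_ready()
    if i % 10 == 0:
        print(f"{proc.memory_info().rss:,}")
```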
Output on CPU:
Output on MPS:
System info (python version, jaxlib version, accelerator, etc.)
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-18 18:17:36.186791: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max
systemMemory: 128.00 GB
maxCacheSize: 48.00 GB
jax: 0.4.25
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 12:57:28) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Jurajs-MacBook-Pro.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')
jax-metal 0.0.6