jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax-metal: Memory leak on jit boundary, MPS #20296

Open youurayy opened 7 months ago

youurayy commented 7 months ago

Description

Simply calling a @jit function with any kind of array input leaks memory at a non-trivial rate. The CPU backend is fine, but the device backend (MPS) exhibits the leak.

Sample program:

import os
import psutil

import jax
import jax.numpy as jnp

# uncomment to test on CPU:
# jax.config.update('jax_platform_name', 'cpu')

key = jax.random.PRNGKey(42)
array = jax.random.uniform(key, shape=(39325, 173), dtype=jnp.float32)

@jax.jit
def jax_func(arr):
    # no-op on purpose: the leak occurs at the jit dispatch boundary itself
    pass

ps = psutil.Process(os.getpid())

for i in range(1000000000):

    if i % 1000000 == 0:
        print(f'{ps.memory_info().rss:,}')

    jax_func(array)

Output on CPU:

241,319,936
243,138,560
243,138,560
243,138,560
243,138,560
243,138,560
243,122,176
243,122,176
243,122,176
243,122,176
243,122,176
243,122,176
243,122,176
243,122,176
243,122,176
243,122,176
243,122,176

Output on MPS:

795,508,736
857,686,016
890,961,920
924,090,368
956,465,152
988,774,400
1,021,132,800
1,053,589,504
1,085,865,984
1,118,224,384
1,150,517,248
1,182,744,576
1,214,988,288
1,247,281,152
1,279,557,632
1,311,801,344
1,344,045,056
1,376,305,152
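To make the growth easier to quantify (and to turn the repro into a pass/fail check), the loop can be wrapped in a small stdlib-only helper. This is a sketch; `rss_bytes` and `grows_unbounded` are illustrative names, not part of JAX, and the threshold is arbitrary:

```python
import resource
import sys

def rss_bytes():
    # ru_maxrss is reported in bytes on macOS and in KiB on Linux.
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return peak if sys.platform == "darwin" else peak * 1024

def grows_unbounded(fn, calls=5000, tol_bytes=1 << 20):
    """Return True if peak RSS grows by more than tol_bytes over `calls` calls."""
    fn()  # warm-up: absorb one-time allocations (tracing, compilation, ...)
    baseline = rss_bytes()
    for _ in range(calls):
        fn()
    return rss_bytes() - baseline > tol_bytes
```

Against the repro above, `grows_unbounded(lambda: jax_func(array))` should return False on CPU and, if the numbers above are representative, True on METAL.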

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-18 18:17:36.186791: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 128.00 GB maxCacheSize: 48.00 GB

jax: 0.4.25
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 12:57:28) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Jurajs-MacBook-Pro.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6031', machine='arm64')

jax-metal 0.0.6

youurayy commented 7 months ago

Can't even try jax-metal version 0.0.4 as recommended here: https://developer.apple.com/metal/jax/ because pip says:

jax-metal 0.0.4 depends on jaxlib==0.4.11

and

ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.11 (from versions: 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25)

It's also not possible to fall back to the oldest available jaxlib, 0.4.17:

The conflict is caused by: The user requested jaxlib==0.4.17 jax-metal 0.0.4 depends on jaxlib==0.4.11

So we're in a bit of a pickle here: any longer/heavier training on MPS slows down dramatically as it swaps gigabytes of allocated (leaked) memory.
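Until there's a fix in the plugin, one mitigation worth trying is to periodically drop JAX's internal caches with `jax.clear_caches()` (public API since jax 0.4.9). This is only a sketch; I have not verified that it releases the Metal plugin's buffers. It's pinned to CPU here so it runs anywhere, but the interesting case is METAL:

```python
import jax
import jax.numpy as jnp

# Demonstrated on CPU so the snippet runs on any machine;
# on Apple Silicon the point would be to run this on METAL.
jax.config.update("jax_platform_name", "cpu")

@jax.jit
def step(x):
    return x * 2.0

x = jnp.ones((1024,), dtype=jnp.float32)
for i in range(5_000):
    y = step(x)
    if (i + 1) % 1_000 == 0:
        # Drops compilation and dispatch caches; the next call re-traces,
        # trading compile time for (hopefully) reclaimed memory.
        jax.clear_caches()
result = float(y[0])
```

The trade-off is a re-trace and re-compile after every `clear_caches()`, so the interval should be large enough that compilation cost stays negligible.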

youurayy commented 7 months ago

jax-metal==0.0.5 with its pip dependency jaxlib==0.4.20 (which pulls jax 0.4.20), on MLIR 1.0.0 (Sonoma 14.2.1), also exhibits the leak.

Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

796,737,536
841,973,760
879,001,600
913,145,856
946,585,600
978,878,464
youurayy commented 6 months ago

Hi guys, is there any outlook on this? It's a showstopper for me, and if nobody is free to have a look, maybe I could.

What is the chance this is a bug in jax-metal? (I don't think Apple's jax-metal sources are public.)

Last question - should this be raised in https://developer.apple.com/forums/tags/tensorflow-metal? (This page points to it.)

shuhand0 commented 6 months ago

This is reproducible and we will take a look.

youurayy commented 5 months ago

For anyone else having the same issue and being blocked on MPS due to it, have a look at the MLX framework which is similar to JAX and tailored specifically to Apple Silicon.

alexlatif commented 2 weeks ago

@youurayy it is similar, but it's also missing a lot. Have you found a good solution for MLX profiling?

Ata-Shaker commented 2 weeks ago

I am also experiencing the same issue. My program starts with 12 GB of memory usage, but after a few epochs, it increases to around 90 GB. At this point, GPU usage drops to 0%, and my code stops training altogether. Any updates or potential workarounds for this problem would be greatly appreciated.

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported! Metal device set to: Apple M3 Pro

systemMemory: 18.00 GB maxCacheSize: 6.00 GB

jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.0
python: 3.11.7 (v3.11.7:fa7a6f2303, Dec 4 2023, 15:22:56) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Atas-MacBook-Pro.fritz.box', release='24.0.0', version='Darwin Kernel Version 24.0.0: Tue Sep 24 23:37:25 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6030', machine='arm64')

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1728661436.923585 7587819 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
I0000 00:00:1728661436.937859 7587819 service.cc:145] XLA service 0x1143d2270 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1728661436.937877 7587819 service.cc:153] StreamExecutor device (0): Metal,
I0000 00:00:1728661436.939592 7587819 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1728661436.939615 7587819 mps_client.cc:384] XLA backend will use up to 12884443136 bytes on device 0 for SimpleAllocator.