pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Memory leak in MCMC? #1699

Closed · ctiennot closed this issue 4 months ago

ctiennot commented 9 months ago

Hi there,

I have an issue: I'm trying to serve a NumPyro model using mlflow and mlserver. The model has varying input sizes and needs to re-estimate all its parameters regularly. I'm using MCMC to estimate them, but memory increases linearly on every call.

I've seen https://github.com/pyro-ppl/numpyro/issues/1347, but the workaround there doesn't help in my case (since the input size keeps changing). Is this a bug? Is there a way to force memory to be released manually? Thanks!

numpyro version: 0.13.2
jax version: 0.4.23

Code to reproduce (adapted from https://github.com/pyro-ppl/numpyro/issues/1347):

from collections import Counter
import gc  # needed for the gc.collect() call below
import os

import numpy as np
import psutil

import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

process = psutil.Process(os.getpid())

print("numpyro version: ", numpyro.__version__)
print("jax version: ", jax.__version__)

def model(y_obs):
    mu = sample('mu', dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    y = sample("y", dist.Normal(mu, sigma), obs=y_obs)

for i in range(10):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=1, num_warmup=2, jit_model_args=True)
    mcmc.run(random.PRNGKey(0), np.zeros((np.random.randint(0, 10),)))
    print("\nGC OBJECTS:")
    cnt = Counter()
    # force collection; it is expected that count of different types
    # should not increase per iteration
    gc.collect()
    for x in gc.get_objects():
        if isinstance(x, list):
            if len(x) > 1:
                cnt[type(x[0])] += 1
    print(cnt.most_common(10))

    memory_in_mb = process.memory_info().rss / 1024 / 1024

    print(i, memory_in_mb)
tare commented 9 months ago

This seems to solve the issue.

fehiepsi commented 9 months ago

I think this is expected. I guess you can release memory with

jax.clear_caches()
for x in jax.live_arrays():
    x.delete()

See https://jax.readthedocs.io/en/latest/jax.html#miscellaneous

ctiennot commented 9 months ago

Thanks for your answers. Unfortunately, I don't think it fixes it entirely. If I just add those lines inside the for loop:

from collections import Counter
import gc  # needed for the gc.collect() call below
import os

import numpy as np
import psutil

import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

process = psutil.Process(os.getpid())

print("numpyro version: ", numpyro.__version__)
print("jax version: ", jax.__version__)

def model(y_obs):
    mu = sample('mu', dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    y = sample("y", dist.Normal(mu, sigma), obs=y_obs)

for i in range(10):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=1, num_warmup=2, jit_model_args=True)
    mcmc.run(random.PRNGKey(0), np.zeros((np.random.randint(0, 10),)))
    print("\nGC OBJECTS:")
    cnt = Counter()
    # force collection; it is expected that count of different types
    # should not increase per iteration
    gc.collect()
    for x in gc.get_objects():
        if isinstance(x, list):
            if len(x) > 1:
                cnt[type(x[0])] += 1
    print(cnt.most_common(10))

    memory_in_mb = process.memory_info().rss / 1024 / 1024

    print(i, memory_in_mb)

    jax.clear_caches()
    for x in jax.live_arrays():
        x.delete()

I still get increasing memory usage:

jax version:  0.4.23
sample: 100%|██████████| 3/3 [00:01<00:00,  1.59it/s, 3 steps of size 6.20e-01. acc. prob=0.67]

GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1221), (<class 'tuple'>, 796), (<class 'method'>, 214), (<class 'jax._src.core.JaxprEqn'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'int'>, 66), (<class 'jax._src.core.Literal'>, 49), (<class 'jax._src.layout.SpecifiedLayout'>, 49)]
0 301.45703125
sample: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, 3 steps of size 5.89e-01. acc. prob=0.34]

GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1380), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 226), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'int'>, 66), (<class 'jax._src.core.Literal'>, 57), (<class 'jax._src.layout.SpecifiedLayout'>, 49)]
1 304.81640625
sample: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, 1 steps of size 1.72e+00. acc. prob=0.17]

GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1537), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 238), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'int'>, 66), (<class 'jax._src.core.Literal'>, 65), (<class 'jax._src.layout.SpecifiedLayout'>, 49)]
2 305.65234375
sample: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, 1 steps of size 1.72e+00. acc. prob=0.17]

GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1696), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 250), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'jax._src.core.Literal'>, 73), (<class 'int'>, 66), (<class 'jax._src.layout.SpecifiedLayout'>, 49)]
3 305.33203125
sample: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, 3 steps of size 5.89e-01. acc. prob=0.52]

GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1857), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 262), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'jax._src.core.Literal'>, 81), (<class 'int'>, 66), (<class 'jax._src.core.ClosedJaxpr'>, 55)]
4 307.87890625
sample: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, 1 steps of size 5.00e+00. acc. prob=0.00]

GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2016), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 274), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'jax._src.core.Literal'>, 89), (<class 'int'>, 66), (<class 'jax._src.core.ClosedJaxpr'>, 66)]
5 308.859375
sample: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, 1 steps of size 5.00e+00. acc. prob=0.00]

GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2175), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 286), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'jax._src.core.Literal'>, 97), (<class 'bool'>, 90), (<class 'jax._src.core.ClosedJaxpr'>, 77), (<class 'int'>, 66)]
6 309.73828125
sample: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s, 1 steps of size 5.00e+00. acc. prob=0.00]

GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2334), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 298), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'jax._src.core.Literal'>, 105), (<class 'bool'>, 90), (<class 'jax._src.core.ClosedJaxpr'>, 88), (<class 'int'>, 66)]
7 309.390625
sample: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, 3 steps of size 5.89e-01. acc. prob=0.66]

GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2493), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 310), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'jax._src.core.Literal'>, 113), (<class 'jax._src.core.ClosedJaxpr'>, 99), (<class 'bool'>, 90), (<class 'int'>, 66)]
8 311.8359375
sample: 100%|██████████| 3/3 [00:01<00:00,  1.69it/s, 3 steps of size 5.89e-01. acc. prob=0.39]

GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2652), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 322), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'jax._src.core.Literal'>, 121), (<class 'jax._src.core.ClosedJaxpr'>, 110), (<class 'bool'>, 90), (<class 'int'>, 66)]
9 312.9375
fehiepsi commented 9 months ago

There are JAX objects, but they might not store data. Could you double-check the memory?

ctiennot commented 9 months ago

> There are JAX objects, but they might not store data. Could you double-check the memory?

The memory does increase (it's already in the output, going from 301.4 MB to 312.9 MB over 10 runs)...

fehiepsi commented 9 months ago

Maybe NumPyro also caches something. I'll check later.

ctiennot commented 9 months ago

> Maybe NumPyro also caches something. I'll check later.

Thanks. FYI, I think this is similar to https://forum.pyro.ai/t/memory-leak-for-hierarchical-mixture-model/3807, which was never solved 😞

fehiepsi commented 9 months ago

Did you get OOM with the above solution? The increase in memory seems small to me.

ctiennot commented 9 months ago

> Did you get OOM with the above solution? The increase in memory seems small to me.

The issue is that I want to serve this model using mlserver, which means memory would slowly increase and eventually make the service fail, and that is not acceptable. The increase is small, but the computation is triggered on every API call, so it can accumulate quickly...

fehiepsi commented 9 months ago

Do you think the memory will increase forever? There are ways to serve a JAX program with dynamic shapes. One of them is to mask the data into a 2**n-sized array (where n varies). This way the number of compiled programs stays small, and recompilation is avoided whenever n matches a previously seen size.
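
To make that concrete, here is a minimal, untested sketch of the padding-and-masking idea; the helpers `next_pow_2` and `pad_to_pow_2` are illustrative names, not NumPyro API:

```python
import numpy as np
import numpyro
from numpyro import sample
from numpyro.handlers import mask
import numpyro.distributions as dist

def next_pow_2(n):
    # smallest power of two >= n, so only a handful of shapes ever get compiled
    return 1 if n <= 1 else 2 ** int(np.ceil(np.log2(n)))

def pad_to_pow_2(y):
    # zero-pad the observations and keep a boolean mask of the real entries
    n = next_pow_2(len(y))
    padded = np.zeros(n)
    padded[: len(y)] = y
    keep = np.arange(n) < len(y)
    return padded, keep

def model(y_obs, keep):
    mu = sample("mu", dist.Normal(0.0, 1.0))
    sigma = sample("sigma", dist.HalfCauchy(3.0))
    # padded entries are masked out, so they contribute nothing to the log density
    with mask(mask=keep):
        sample("y", dist.Normal(mu, sigma), obs=y_obs)
```

Calling `mcmc.run(rng_key, *pad_to_pow_2(y))` would then recompile only when the padded length crosses a power of two; together with `jit_model_args=True`, previously seen sizes reuse their compiled program.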

tare commented 9 months ago

I'm not sure how actionable this is, but I used the following code:

import numpy as np
import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

jax.config.update("jax_platform_name", "cpu")

print("numpyro version: ", numpyro.__version__)
print("jax version: ", jax.__version__)

def model(y_obs):
    mu = sample("mu", dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    y = sample("y", dist.Normal(mu, sigma), obs=y_obs)

for i in range(100):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=1, num_warmup=2, jit_model_args=True, progress_bar=False)
    mcmc.run(random.PRNGKey(0), np.zeros((1,)))

    jax.clear_caches()

    jax.profiler.save_device_memory_profile(f"memory{i}.prof")

and then used the --diff_base feature of pprof to visualize the change in memory usage between iterations (something like `pprof --web --diff_base=memory0.prof memory99.prof`):

i=0 vs. i=9:

[image: profile_0_vs_9]

i=0 vs. i=99:

[image: profile_0_vs_99]

fehiepsi commented 9 months ago

I think the issue happens with dynamic-shape inputs. It's not clear to me that this is a NumPyro issue. @ctiennot What happens if you remove the NumPyro parts from your code?

ctiennot commented 9 months ago

Hey, sorry for the late reply:

> Do you think the memory will increase forever? There are ways to serve a JAX program with dynamic shapes. One of them is to mask the data into a 2**n-sized array (where n varies). This way the number of compiled programs stays small, and recompilation is avoided whenever n matches a previously seen size.

I tried it without success, but maybe I'm not doing it right... Also, I think it's not very convenient to have to pad the inputs. For now I am using another MCMC implementation to carry out the inference and avoid the NumPyro memory issue 😢 If I have time I'll give it another try.

> I think the issue happens with dynamic-shape inputs. It's not clear to me that this is a NumPyro issue. @ctiennot What happens if you remove the NumPyro parts from your code?

Hmm, what do you mean by removing the NumPyro parts? I'm not sure I get it 🤔

Also, thanks for the profiling graphs @tare :)

tare commented 9 months ago

@fehiepsi I did run a couple of experiments without NumPyro to see the effects of jax.jit(), jax.clear_caches(), and dynamic shapes. Did you mean something like that?

These results suggest that the issue is not due to NumPyro. However, I don't know the internals of JAX well enough to explain what is happening here.

w/out jax.jit(), w/out jax.clear_caches(), and static shapes

Code:

```python
import jax
import jax.numpy as jnp
import os
import psutil
import gc

process = psutil.Process(os.getpid())

print("jax version: ", jax.__version__)

def fn(x):
    return jnp.sum(jnp.sin(x))

for i in range(0, 1_001):
    res = fn(jnp.zeros(10))
    del res
    for x in jax.live_arrays():
        x.delete()
    gc.collect()
    if i % 100 == 0:
        print(f"{i=}: {process.memory_info().rss / 1024 / 1024}")
```

Output:

```
jax version:  0.4.23
i=0: 99.2421875
i=100: 99.27734375
i=200: 99.27734375
i=300: 99.27734375
i=400: 99.27734375
i=500: 99.27734375
i=600: 99.27734375
i=700: 99.28125
i=800: 99.28125
i=900: 99.28125
i=1000: 99.28125
```

w/out jax.jit(), w/ jax.clear_caches(), and static shapes

Code:

```python
import jax
import jax.numpy as jnp
import os
import psutil
import gc

process = psutil.Process(os.getpid())

print("jax version: ", jax.__version__)

def fn(x):
    return jnp.sum(jnp.sin(x))

for i in range(0, 1_001):
    res = fn(jnp.zeros(10))
    jax.clear_caches()
    del res
    for x in jax.live_arrays():
        x.delete()
    gc.collect()
    if i % 100 == 0:
        print(f"{i=}: {process.memory_info().rss / 1024 / 1024}")
```

Output:

```
jax version:  0.4.23
i=0: 97.3125
i=100: 107.6484375
i=200: 108.23046875
i=300: 108.65625
i=400: 108.69140625
i=500: 108.80078125
i=600: 109.14453125
i=700: 109.15234375
i=800: 109.15625
i=900: 109.1796875
i=1000: 109.18359375
```

w/out jax.jit(), w/out jax.clear_caches(), and dynamic shapes

Code:

```python
import jax
import jax.numpy as jnp
import os
import psutil
import gc

process = psutil.Process(os.getpid())

print("jax version: ", jax.__version__)

def fn(x):
    return jnp.sum(jnp.sin(x))

for i in range(0, 1_001):
    res = fn(jnp.zeros(i))
    del res
    for x in jax.live_arrays():
        x.delete()
    gc.collect()
    if i % 100 == 0:
        print(f"{i=}: {process.memory_info().rss / 1024 / 1024}")
```

Output:

```
jax version:  0.4.23
i=0: 91.18359375
i=100: 246.37109375
i=200: 390.8203125
i=300: 534.91015625
i=400: 679.5
i=500: 823.40625
i=600: 967.95703125
i=700: 1101.5234375
i=800: 1165.6484375
i=900: 1229.5859375
i=1000: 1291.671875
```

w/out jax.jit(), w/ jax.clear_caches(), and dynamic shapes

Code:

```python
import jax
import jax.numpy as jnp
import os
import psutil
import gc

process = psutil.Process(os.getpid())

print("jax version: ", jax.__version__)

def fn(x):
    return jnp.sum(jnp.sin(x))

for i in range(0, 1_001):
    res = fn(jnp.zeros(i))
    jax.clear_caches()
    del res
    for x in jax.live_arrays():
        x.delete()
    gc.collect()
    if i % 100 == 0:
        print(f"{i=}: {process.memory_info().rss / 1024 / 1024}")
```

Output:

```
jax version:  0.4.23
i=0: 91.578125
i=100: 109.5234375
i=200: 111.6171875
i=300: 103.02734375
i=400: 104.6015625
i=500: 105.97265625
i=600: 107.390625
i=700: 108.90234375
i=800: 110.33984375
i=900: 111.76953125
i=1000: 113.30078125
```

w/ jax.jit(), w/out jax.clear_caches(), and static shapes

Code:

```python
import jax
import jax.numpy as jnp
import os
import psutil
import gc

process = psutil.Process(os.getpid())

print("jax version: ", jax.__version__)

@jax.jit
def fn(x):
    return jnp.sum(jnp.sin(x))

for i in range(0, 1_001):
    res = fn(jnp.zeros(10))
    del res
    for x in jax.live_arrays():
        x.delete()
    gc.collect()
    if i % 100 == 0:
        print(f"{i=}: {process.memory_info().rss / 1024 / 1024}")
```

Output:

```
jax version:  0.4.23
i=0: 97.42578125
i=100: 97.4453125
i=200: 97.44921875
i=300: 97.44921875
i=400: 97.44921875
i=500: 97.453125
i=600: 97.453125
i=700: 97.453125
i=800: 97.453125
i=900: 97.453125
i=1000: 97.453125
```

w/ jax.jit(), w/ jax.clear_caches(), and static shapes

Code:

```python
import jax
import jax.numpy as jnp
import os
import psutil
import gc

process = psutil.Process(os.getpid())

print("jax version: ", jax.__version__)

@jax.jit
def fn(x):
    return jnp.sum(jnp.sin(x))

for i in range(0, 1_001):
    res = fn(jnp.zeros(10))
    jax.clear_caches()
    del res
    for x in jax.live_arrays():
        x.delete()
    gc.collect()
    if i % 100 == 0:
        print(f"{i=}: {process.memory_info().rss / 1024 / 1024}")
```

Output:

```
jax version:  0.4.23
i=0: 97.17578125
i=100: 104.63671875
i=200: 106.3984375
i=300: 106.44140625
i=400: 106.54296875
i=500: 106.5546875
i=600: 106.5625
i=700: 106.58203125
i=800: 106.5859375
i=900: 106.76171875
i=1000: 106.76953125
```

w/ jax.jit(), w/out jax.clear_caches(), and dynamic shapes

Code:

```python
import jax
import jax.numpy as jnp
import os
import psutil
import gc

process = psutil.Process(os.getpid())

print("jax version: ", jax.__version__)

@jax.jit
def fn(x):
    return jnp.sum(jnp.sin(x))

for i in range(0, 1_001):
    res = fn(jnp.zeros(i))
    del res
    for x in jax.live_arrays():
        x.delete()
    gc.collect()
    if i % 100 == 0:
        print(f"{i=}: {process.memory_info().rss / 1024 / 1024}")
```

Output:

```
jax version:  0.4.23
i=0: 90.64453125
i=100: 200.91796875
i=200: 303.6953125
i=300: 403.40234375
i=400: 505.61328125
i=500: 604.90234375
i=600: 705.3515625
i=700: 804.84765625
i=800: 906.359375
i=900: 1005.77734375
i=1000: 1106.8046875
```

w/ jax.jit(), w/ jax.clear_caches(), and dynamic shapes

Code:

```python
import jax
import jax.numpy as jnp
import os
import psutil
import gc

process = psutil.Process(os.getpid())

print("jax version: ", jax.__version__)

@jax.jit
def fn(x):
    return jnp.sum(jnp.sin(x))

for i in range(0, 1_001):
    res = fn(jnp.zeros(i))
    jax.clear_caches()
    del res
    for x in jax.live_arrays():
        x.delete()
    gc.collect()
    if i % 100 == 0:
        print(f"{i=}: {process.memory_info().rss / 1024 / 1024}")
```

Output:

```
jax version:  0.4.23
i=0: 90.734375
i=100: 108.7265625
i=200: 111.09375
i=300: 112.77734375
i=400: 114.72265625
i=500: 116.40234375
i=600: 118.37890625
i=700: 119.9921875
i=800: 121.80078125
i=900: 123.40625
i=1000: 125.1953125
```
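
Summarizing the eight runs above (RSS in MB, taken from the outputs, rounded to one decimal):

| `jax.jit()` | `jax.clear_caches()` | shapes | RSS at i=0 | RSS at i=1000 |
| --- | --- | --- | --- | --- |
| no | no | static | 99.2 | 99.3 |
| no | yes | static | 97.3 | 109.2 |
| no | no | dynamic | 91.2 | 1291.7 |
| no | yes | dynamic | 91.6 | 113.3 |
| yes | no | static | 97.4 | 97.5 |
| yes | yes | static | 97.2 | 106.8 |
| yes | no | dynamic | 90.6 | 1106.8 |
| yes | yes | dynamic | 90.7 | 125.2 |

Dynamic shapes without `jax.clear_caches()` grow roughly linearly without bound; with `jax.clear_caches()` the growth is reduced to a slow creep.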
fehiepsi commented 9 months ago

Huge thanks, @tare! I think we can file a JAX issue with your reproducible code.

fehiepsi commented 4 months ago

The issue seems to be explained upstream. Let's close this, I guess.