I think this is expected. I guess you can release memory with
jax.clear_caches()
for x in jax.live_arrays():
    x.delete()
See https://jax.readthedocs.io/en/latest/jax.html#miscellaneous
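For illustration, here is a small helper wrapping the two suggestions (my sketch, not code from the thread; both calls are real JAX APIs documented at the link above):

import jax

def release_jax_memory():
    # drop JAX's compilation/tracing caches
    jax.clear_caches()
    # delete the device buffers of all arrays JAX still considers alive
    for x in jax.live_arrays():
        x.delete()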
Thanks for your answer. Unfortunately, I don't think it fixes the problem entirely; if I just add those two lines inside the for loop:
from collections import Counter
import gc
import os

import numpy as np
import psutil

import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from jax import random, vmap
from numpyro.infer import MCMC, NUTS

process = psutil.Process(os.getpid())

print("numpyro version: ", numpyro.__version__)
print("jax version: ", jax.__version__)

def model(y_obs):
    mu = sample("mu", dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    y = sample("y", dist.Normal(mu, sigma), obs=y_obs)

for i in range(10):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=1, num_warmup=2, jit_model_args=True)
    mcmc.run(random.PRNGKey(0), np.zeros((np.random.randint(0, 10),)))

    print("\nGC OBJECTS:")
    cnt = Counter()
    # force collection; it is expected that the count of each type
    # should not increase per iteration
    gc.collect()
    for x in gc.get_objects():
        if isinstance(x, list):
            if len(x) > 1:
                cnt[type(x[0])] += 1
    print(cnt.most_common(10))

    memory_in_mb = process.memory_info().rss / 1024 / 1024
    print(i, memory_in_mb)

    jax.clear_caches()
    for x in jax.live_arrays():
        x.delete()
I still get increasing memory usage:
jax version: 0.4.23
sample: 100%|██████████| 3/3 [00:01<00:00, 1.59it/s, 3 steps of size 6.20e-01. acc. prob=0.67]
GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1221), (<class 'tuple'>, 796), (<class 'method'>, 214), (<class 'jax._src.core.JaxprEqn'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'int'>, 66), (<class 'jax._src.core.Literal'>, 49), (<class 'jax._src.layout.SpecifiedLayout'>, 49)]
0 301.45703125
sample: 100%|██████████| 3/3 [00:01<00:00, 1.69it/s, 3 steps of size 5.89e-01. acc. prob=0.34]
GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1380), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 226), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'int'>, 66), (<class 'jax._src.core.Literal'>, 57), (<class 'jax._src.layout.SpecifiedLayout'>, 49)]
1 304.81640625
sample: 100%|██████████| 3/3 [00:01<00:00, 1.69it/s, 1 steps of size 1.72e+00. acc. prob=0.17]
GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1537), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 238), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'int'>, 66), (<class 'jax._src.core.Literal'>, 65), (<class 'jax._src.layout.SpecifiedLayout'>, 49)]
2 305.65234375
sample: 100%|██████████| 3/3 [00:01<00:00, 1.69it/s, 1 steps of size 1.72e+00. acc. prob=0.17]
GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1696), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 250), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'jax._src.core.Literal'>, 73), (<class 'int'>, 66), (<class 'jax._src.layout.SpecifiedLayout'>, 49)]
3 305.33203125
sample: 100%|██████████| 3/3 [00:01<00:00, 1.69it/s, 3 steps of size 5.89e-01. acc. prob=0.52]
GC OBJECTS:
[(<class 'str'>, 1957), (<class 'jax._src.core.Var'>, 1857), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 262), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'jax._src.core.Literal'>, 81), (<class 'int'>, 66), (<class 'jax._src.core.ClosedJaxpr'>, 55)]
4 307.87890625
sample: 100%|██████████| 3/3 [00:01<00:00, 1.67it/s, 1 steps of size 5.00e+00. acc. prob=0.00]
GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2016), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 274), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'bool'>, 90), (<class 'jax._src.core.Literal'>, 89), (<class 'int'>, 66), (<class 'jax._src.core.ClosedJaxpr'>, 66)]
5 308.859375
sample: 100%|██████████| 3/3 [00:01<00:00, 1.67it/s, 1 steps of size 5.00e+00. acc. prob=0.00]
GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2175), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 286), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'jax._src.core.Literal'>, 97), (<class 'bool'>, 90), (<class 'jax._src.core.ClosedJaxpr'>, 77), (<class 'int'>, 66)]
6 309.73828125
sample: 100%|██████████| 3/3 [00:01<00:00, 1.67it/s, 1 steps of size 5.00e+00. acc. prob=0.00]
GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2334), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 298), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'jax._src.core.Literal'>, 105), (<class 'bool'>, 90), (<class 'jax._src.core.ClosedJaxpr'>, 88), (<class 'int'>, 66)]
7 309.390625
sample: 100%|██████████| 3/3 [00:01<00:00, 1.69it/s, 3 steps of size 5.89e-01. acc. prob=0.66]
GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2493), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 310), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'jax._src.core.Literal'>, 113), (<class 'jax._src.core.ClosedJaxpr'>, 99), (<class 'bool'>, 90), (<class 'int'>, 66)]
8 311.8359375
sample: 100%|██████████| 3/3 [00:01<00:00, 1.69it/s, 3 steps of size 5.89e-01. acc. prob=0.39]
GC OBJECTS:
[(<class 'jax._src.core.Var'>, 2652), (<class 'str'>, 1957), (<class 'tuple'>, 796), (<class 'jax._src.core.JaxprEqn'>, 322), (<class 'method'>, 214), (<class 'traitlets.traitlets.ObserveHandler'>, 208), (<class 'jax._src.core.Literal'>, 121), (<class 'jax._src.core.ClosedJaxpr'>, 110), (<class 'bool'>, 90), (<class 'int'>, 66)]
9 312.9375
There are jax objects, but they might not store data. Could you double-check the memory?
The memory does increase; it's already in the output, going from ~301.4 MB to ~312.9 MB over 10 runs...
maybe numpyro also caches something. I'll check later.
Thanks! FYI, I think this is similar to https://forum.pyro.ai/t/memory-leak-for-hierarchical-mixture-model/3807, which was never solved 😞
Did you get OOM with the above solution? The increase in memory seems small to me.
The issue is that I want to serve this model using mlserver, which means memory will slowly increase and eventually cause failures, which is not acceptable. The increase is small, but the computation is triggered on every API call, so it can accumulate quickly...
Do you think that memory will increase forever? There are ways to serve a JAX program with dynamic shapes. One of them is to mask the data into a 2**n-length array (where n varies). This way the number of compiled programs stays small, and recompilation is avoided whenever n matches a previously seen value. A concrete sketch follows.
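To make this concrete, here is a minimal sketch of the padding/masking idea (my illustration, not code from the thread; pad_to_pow2 is a hypothetical helper). Padded entries are masked out so they contribute nothing to the log-likelihood, and recompilation only happens when the padded length changes:

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro import sample

def padded_model(y_obs, is_real):
    mu = sample("mu", dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    # numpyro.handlers.mask zeroes out the log-prob of masked-out entries
    with numpyro.handlers.mask(mask=is_real):
        sample("y", dist.Normal(mu, sigma), obs=y_obs)

def pad_to_pow2(y):
    # hypothetical helper: zero-pad y to the next power-of-two length
    n = 1 if len(y) == 0 else int(2 ** np.ceil(np.log2(len(y))))
    padded = np.zeros((n,), dtype=y.dtype)
    padded[: len(y)] = y
    mask = np.arange(n) < len(y)  # True for real data, False for padding
    return padded, mask

With jit_model_args=True, mcmc.run(rng_key, *pad_to_pow2(y)) would then reuse the compiled program for every y whose padded length has been seen before.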
I'm not sure how actionable this is but I used the following code:
import numpy as np
import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from jax import random, vmap
from numpyro.infer import MCMC, NUTS
import os

jax.config.update("jax_platform_name", "cpu")

print("numpyro version: ", numpyro.__version__)
print("jax version: ", jax.__version__)

def model(y_obs):
    mu = sample("mu", dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    y = sample("y", dist.Normal(mu, sigma), obs=y_obs)

for i in range(100):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=1, num_warmup=2, jit_model_args=True, progress_bar=False)
    mcmc.run(random.PRNGKey(0), np.zeros((1,)))
    jax.clear_caches()
    jax.profiler.save_device_memory_profile(f"memory{i}.prof")
and then used the --diff_base feature of pprof (something like pprof --web --diff_base=memory0.prof memory99.prof) to visualize the change in memory usage between iterations. (Profiling graphs comparing i == 0 vs i == 9 and i == 0 vs i == 99 were attached.)
I think the issue happens with dynamic-shape inputs. It's not clear to me that this is a numpyro issue. @ctiennot What happens if you remove numpyro stuff from your code?
Hey, sorry for the late reply:
Do you think that memory will increase forever? There are ways to serve a JAX program with dynamic shapes. One of them is to mask the data into a 2**n-length array (where n varies). This way the number of compiled programs stays small, and recompilation is avoided whenever n matches a previously seen value.
I tried it without success, but maybe I'm not doing it right... Also, I think it's not super convenient to have to pad the inputs. Eventually, I am now using another MCMC implementation to run the inference and avoid the NumPyro memory issue 😢 If I have time I'll give it another try.
I think the issue happens with dynamic-shape inputs. It's not clear to me that this is a numpyro issue. @ctiennot What happens if you remove numpyro stuff from your code?
Hm, what do you mean by removing the numpyro stuff? I'm not sure I get it 🤔
Also thanks for the profiling graph @tare :)
@fehiepsi I did run a couple of experiments without NumPyro to see the effects of jax.jit(), jax.clear_caches(), and dynamic shapes. Did you mean something like that? These results suggest that the issue is not due to NumPyro. However, I don't know the internals of JAX well enough to explain what is happening here. A minimal pure-JAX sketch of such an experiment follows the list.
(Profiling graphs were attached for the following configurations, each appearing twice:)
- jax.jit(), w/out jax.clear_caches(), and static shapes
- jax.jit(), w/ jax.clear_caches(), and static shapes
- jax.jit(), w/out jax.clear_caches(), and dynamic shapes
- jax.jit(), w/ jax.clear_caches(), and dynamic shapes
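For reference, a minimal pure-JAX sketch of this kind of experiment (my reconstruction under stated assumptions, not tare's exact script; the two flags are illustrative):

import os
import gc
import psutil
import numpy as np
import jax
import jax.numpy as jnp

process = psutil.Process(os.getpid())

@jax.jit
def f(x):
    return jnp.sum(x ** 2)

DYNAMIC_SHAPES = True   # feed a new input shape each iteration -> recompile every time
CLEAR_CACHES = True     # call jax.clear_caches() after every iteration

for i in range(100):
    n = i + 1 if DYNAMIC_SHAPES else 10
    f(np.zeros((n,))).block_until_ready()
    if CLEAR_CACHES:
        jax.clear_caches()
    gc.collect()
    print(i, process.memory_info().rss / 1024 / 1024)  # RSS in MB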
Huge thanks, @tare! I think we can file a jax issue with your reproducible code.
The issue seems to be explained upstream. Let's close this, I guess.
Hi there,
I have an issue: I'm trying to serve a NumPyro model using mlflow and mlserver. My model has varying input sizes and needs to re-estimate all parameters regularly. I'm using MCMC to estimate them, but the memory increases linearly each time it is called.
I've seen https://github.com/pyro-ppl/numpyro/issues/1347, but it doesn't fix it in my case (as the size of the inputs is changing). Is this a bug? Is there a way to manually force memory to be released? Thanks!
numpyro version: 0.13.2
jax version: 0.4.23
Code to reproduce (taken from https://github.com/pyro-ppl/numpyro/issues/1347).
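(Presumably this is the same reproduction script shown earlier in the thread; a condensed version for reference:)

import numpy as np
import numpyro, jax
from numpyro import sample
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

def model(y_obs):
    mu = sample("mu", dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    sample("y", dist.Normal(mu, sigma), obs=y_obs)

# re-running MCMC with inputs of varying shape shows steadily growing memory
for i in range(10):
    mcmc = MCMC(NUTS(model), num_samples=1, num_warmup=2, jit_model_args=True)
    mcmc.run(random.PRNGKey(0), np.zeros((np.random.randint(0, 10),)))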