alucantonio opened this issue 1 year ago
Is it specific to LBFGS or does it happen with any solver?
Hi, I have experienced the issue with LBFGS and GradientDescent. The increase in memory is less evident with GradientDescent, but it is still there. I believe the issue does not depend on the solver.
Can you also check if the `jit` and `unroll` options to `LBFGS` have any impact on this? Depending on these options, a different loop implementation is used.
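For reference, a quick sketch of how those flags can be passed when constructing the solver; the toy objective here is an assumption that mirrors the snippets later in the thread:

```python
import jax.numpy as jnp
import jaxopt

def obj(x, min):
    return jnp.square(x - min).sum()

# Depending on these flags, jaxopt picks a different internal loop implementation.
solver_a = jaxopt.LBFGS(obj, jit=True, unroll=False, maxiter=100)
solver_b = jaxopt.LBFGS(obj, jit=False, unroll=True, maxiter=100)

params_a = solver_a.run(jnp.zeros(1), min=jnp.array(3.0)).params
params_b = solver_b.run(jnp.zeros(1), min=jnp.array(3.0)).params
```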
Normally, solver objects don't store anything, so I'm not sure where this could come from...
CC @fabianp, @froystig
Setting `jit=False` and `unroll=True`, or `jit=True` and `unroll=False`, and using LBFGS still produces an increase in memory after each call of `solver.run`.
I don't have a solution for this, but I can confirm that it also happens with the `update` API, i.e., when updates are run inside a for loop:
```python
import jax
import jax.numpy as jnp
import jaxopt

def optimize(min):
    def obj(x, min):
        return jnp.square(x - min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)
    solver = jaxopt.LBFGS(obj, implicit_diff=False, maxiter=100)
    state = solver.init_state(x0, min=mm)
    jitted_update = jax.jit(solver.update)
    params = x0
    for _ in range(solver.maxiter):
        params, state = jitted_update(params, state, min=mm)
```
Some updates on my investigations. With the update call inside the for loop commented out, as below, the memory does not grow:
```python
import jax.numpy as jnp
import jaxopt
import jax
import gc
import time

def optimize(min):
    def obj(x, min):
        return jnp.square(x - min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)
    solver = jaxopt.LBFGS(obj)
    state = solver.init_state(x0, min=mm)
    jitted_update = jax.jit(solver.update)
    params = x0
    for _ in range(solver.maxiter):
        pass
        # params, state = jitted_update(params, state, min=mm)
    time.sleep(1)

for i in range(10):
    optimize(i)
    gc.collect()
```
However, if I uncomment the update line inside the for loop (even for just 1 iteration), the leak comes back.
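If it helps with reproducing, here is a rough sketch of how the per-call growth can be observed; psutil and the RSS-based measurement are my own choice here, not what was used for the plots above:

```python
import psutil  # assumed available; memory_profiler reports a similar picture

import jax
import jax.numpy as jnp
import jaxopt

def optimize(min):
    def obj(x, min):
        return jnp.square(x - min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)
    solver = jaxopt.LBFGS(obj, implicit_diff=False, maxiter=100)
    state = solver.init_state(x0, min=mm)
    jitted_update = jax.jit(solver.update)
    params = x0
    for _ in range(solver.maxiter):
        params, state = jitted_update(params, state, min=mm)

process = psutil.Process()
for i in range(10):
    optimize(i)
    # On the affected versions, resident memory keeps growing after each call.
    print(f"call {i}: rss = {process.memory_info().rss / 1e6:.1f} MB")
```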
In your example, is there still a leak if the update is not jitted?
Yeah, although there's a small decrease at the end, which could mean it's recovering some memory.
This is without jitting: [memory profile plot]
and with jitting: [memory profile plot]
As you can see, it's also using a lot more memory when it's not jitting. Not sure what to make of that.
Thanks for the investigations. I would like to know whether this behavior can be considered a bug and whether there is any plan to fix it.
Yes to both. Seems like a bug and should be fixed (although we're all spread too thin, so I wouldn't know how to set a timeline on it).
Hi, has there been any progress on this issue?
This behavior can be avoided using the newly implemented `jax.clear_caches()` in jax (thanks @froystig!). For example, the code below doesn't have the ever-increasing memory profile. Instead, it has the more expected initial increase and then a plateau:
```python
import jax.numpy as jnp
import jaxopt
import jax
import gc
import time

def optimize(min):
    def obj(x, min):
        return jnp.square(x - min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)
    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)
    jax.clear_caches()
```
I'm going to close the issue for now, but please reopen if the problem persists (BTW, you might need the development version of jax for the `clear_caches()` function).
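In case it helps anyone on a stable release, a small variation of the workaround with a guard; the `hasattr` check is my addition, and `optimize` is the function from the snippet above:

```python
import jax

for i in range(10):
    optimize(i)
    # jax.clear_caches() only exists in recent jax versions; skip it otherwise.
    if hasattr(jax, "clear_caches"):
        jax.clear_caches()
```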
It's nice to have a workaround but shouldn't garbage collection be able to do this automatically?
Maybe, but at this point it seems more of an issue concerning jax than jaxopt, wdyt?
Agreed!
@froystig made a good point in a private conversation: this might be symptomatic of jaxopt not using the cache properly and/or generating too many fresh functions instead of reusing the cache.
I don't have the bandwidth to look into it right now, but I'm leaving this open in case someone can look into it more deeply.
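To make that hypothesis concrete (this is only my sketch of the idea, not a confirmed diagnosis): jit caches are keyed on function identity, so defining the objective and the solver inside `optimize` creates fresh functions on every call, whereas hoisting them to module level at least gives the caches a chance to be reused:

```python
import jax.numpy as jnp
import jaxopt

# Created once, so the functions traced and jitted by jaxopt keep the same
# identity across calls instead of being recreated every time.
def obj(x, min):
    return jnp.square(x - min).sum()

solver = jaxopt.LBFGS(obj, maxiter=100)

def optimize(min):
    x0 = jnp.zeros(1)
    mm = jnp.array(min)
    return solver.run(x0, min=mm).params[0]

for i in range(10):
    print(optimize(i))
```

Whether this actually flattens the memory profile would still need to be checked with the profiler.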
I am trying to solve a problem where `solver.run` is called multiple times to minimize a series of functions while varying a parameter. Using `memory_profiler`, I can see that the allocated memory increases each time `solver.run` is called and never decreases. Here is a minimal example to reproduce the issue:
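A minimal sketch of such an example, assuming it follows the same obj/optimize pattern used throughout this thread (without the `clear_caches()` workaround):

```python
import jax.numpy as jnp
import jaxopt

def optimize(min):
    def obj(x, min):
        return jnp.square(x - min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)
    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)
```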
And here is the corresponding plot of the allocated memory: [memory_profiler plot]
Can you please confirm the issue or suggest a solution? Thanks, Alessandro