google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Possible memory leak when calling solver.run multiple times #380

Open alucantonio opened 1 year ago

alucantonio commented 1 year ago

I am trying to solve a problem where solver.run is called multiple times to minimize a series of functions while varying a parameter. Using memory_profiler I can see that the allocated memory increases each time the function solver.run is called and never decreases.

Here is a minimal example to reproduce the issue:

import jax.numpy as jnp
import jaxopt
from memory_profiler import profile

@profile
def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)

And here is the corresponding plot of the allocated memory: Figure_1

Can you please confirm the issue or provide a solution for that? Thanks. Alessandro
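
The per-call measurement described above can also be taken without memory_profiler. The following is a minimal sketch using only the standard library, assuming a Unix system (the `resource` module is not available on Windows). The names `peak_rss_mib`, `track_growth` and `dummy_workload` are made up for illustration; `dummy_workload` is a stand-in for `optimize(i)`. Note that `ru_maxrss` reports the peak resident set size, so the readings can only grow or stay flat; a steady climb across calls is what suggests a leak.

```python
import resource
import sys


def peak_rss_mib():
    """Peak resident set size of this process in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def track_growth(workload, n_calls):
    """Call workload(i) n_calls times and record the peak RSS after each call."""
    readings = []
    for i in range(n_calls):
        workload(i)
        readings.append(peak_rss_mib())
    return readings


def dummy_workload(i):
    # Hypothetical stand-in for optimize(i): allocates and drops a temporary list.
    _ = [0] * 100_000


readings = track_growth(dummy_workload, 5)
for i, mib in enumerate(readings):
    print(f"call {i}: peak RSS {mib:.1f} MiB")
```

Replacing `dummy_workload` with the `optimize` function from the example gives one reading per call to solver.run.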

mblondel commented 1 year ago

Is it specific to LBFGS or does it happen with any solver?

alucantonio commented 1 year ago

Hi, I have experienced the issue with LBFGS and GradientDescent. The increase in memory is less evident with GradientDescent, but it is still there. I believe the issue does not depend on the solver.

mblondel commented 1 year ago

Can you also check if the jit and unroll options to LBFGS have any impact on this? Depending on these options, a different loop implementation is used.

Normally, solver objects don't store anything, so I'm not sure where this could come from...

CC @fabianp, @froystig

alucantonio commented 1 year ago

Setting jit=False and unroll=True or jit=True and unroll=False and using LBFGS still produces an increase in memory after each call of solver.run.

fabianp commented 1 year ago

I don't have a solution for this, but I can confirm that it also happens with the update API, i.e., when updates are run inside a for loop:

import jax
import jax.numpy as jnp
import jaxopt

def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, implicit_diff=False, maxiter=100)
    state = solver.init_state(x0, min=mm)
    # Jit a single update step and drive the iterations from Python.
    jitted_update = jax.jit(solver.update)
    params = x0
    for _ in range(solver.maxiter):
        params, state = jitted_update(params, state, min=mm)

fabianp commented 1 year ago

Some updates on my investigations.

  1. Following @mblondel's suggestion, I set eq=True in the definition of LBFGS. It didn't help.
  2. I also modified the LBFGS class to remove the dataclass decorator. It didn't help.
  3. I'm inclined to think the issue is in the update method. The following code, which constructs the solver but doesn't perform the updates, doesn't have the memory leak:
    
    import jax.numpy as jnp
    import jaxopt
    import jax
    import gc
    import time

    def optimize(min):

        def obj(x, min):
            return jnp.square(x-min).sum()

        x0 = jnp.zeros(1)
        mm = jnp.array(min)

        solver = jaxopt.LBFGS(obj)
        state = solver.init_state(x0, min=mm)
        jitted_update = jax.jit(solver.update)
        params = x0
        for _ in range(solver.maxiter):
            pass
            # params, state = jitted_update(params, state, min=mm)
        time.sleep(1)

    for i in range(10):
        optimize(i)
        gc.collect()

However, if I uncomment the lines inside the for loop (even for just 1 iteration), the leak comes back.

froystig commented 1 year ago

In your example, is there still a leak if the update is not jitted?

fabianp commented 1 year ago

Yeah, although there's a small decrease at the end, which could mean it's reclaiming some memory.

This is without jitting: image

and with jitting: image

As you can see, it also uses a lot more memory when it's not jitting. I'm not sure what to make of that.

alucantonio commented 1 year ago

Thanks for the investigations. I would like to know whether this behavior can be considered as a bug and whether there is any plan to fix it.

fabianp commented 1 year ago

Yes to both. It seems like a bug and should be fixed (although we're all spread too thin, so I wouldn't know how to set a timeline on it).

On Thu, Feb 23, 2023, 09:49 Alessandro Lucantonio @.***> wrote:

Thanks for the investigations. I would like to know whether this behavior can be considered as a bug and whether there is any plan to fix it.


alucantonio commented 1 year ago

Hi, has there been any progress on this issue?

fabianp commented 1 year ago

This behavior can be avoided by calling the newly added jax.clear_caches() (thanks @froystig!).

For example, the code below doesn't show the ever-increasing memory profile. Instead, it shows the expected initial increase followed by a plateau:

Figure_1

import jax.numpy as jnp
import jaxopt
import jax
import gc
import time

def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)
    jax.clear_caches()

fabianp commented 1 year ago

I'm going to close the issue for now, but please reopen it if the problem persists. (BTW, you might need the development version of jax for the clear_caches() function.)
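
One cost of this workaround is worth keeping in mind: clearing the caches also discards work that could have been reused. The sketch below is plain JAX rather than jaxopt, and it assumes a JAX version that provides `jax.clear_caches()`. It counts how often a jitted function is traced, using a Python side effect that only runs during tracing. The names `loss` and `trace_count` are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def loss(x, target):
    global trace_count
    trace_count += 1  # Python side effect: executes only while JAX traces `loss`
    return jnp.square(x - target).sum()


x0 = jnp.zeros(1)

loss(x0, jnp.array(1.0))
loss(x0, jnp.array(2.0))  # same shapes and dtypes, so the cached trace is reused
traces_with_cache = trace_count

jax.clear_caches()        # drop JAX's in-memory tracing and compilation caches
loss(x0, jnp.array(3.0))  # has to be traced and compiled again
traces_after_clear = trace_count

print(traces_with_cache, traces_after_clear)
```

Calling `jax.clear_caches()` after every solve therefore trades memory for repeated tracing and compilation. If compile time matters, clearing only every few solves may be a reasonable compromise.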

mblondel commented 1 year ago

It's nice to have a workaround but shouldn't garbage collection be able to do this automatically?
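
As a general Python illustration of why garbage collection alone may not be enough: the collector only reclaims objects that are no longer reachable, and anything still referenced from a long-lived cache stays reachable. The sketch below is a generic analogy with made-up names (`Compiled`, `_cache`, `compile_and_cache`); it is not a description of how JAX or jaxopt manage their caches.

```python
import gc
import weakref


class Compiled:
    """Hypothetical stand-in for an expensive cached artifact."""

    def __init__(self, tag):
        self.tag = tag


_cache = {}  # hypothetical module-level cache holding strong references


def compile_and_cache(key):
    artifact = Compiled(key)
    _cache[key] = artifact
    # Return only a weak reference so the caller does not keep the object alive.
    return weakref.ref(artifact)


probe = compile_and_cache("solve-0")
gc.collect()
alive_before_clear = probe() is not None  # still reachable through _cache

_cache.clear()
gc.collect()
alive_after_clear = probe() is not None  # no strong references remain

print(alive_before_clear, alive_after_clear)
```

In this analogy, `gc.collect()` has nothing to free until the cache entry itself is dropped, which is the role an explicit cache-clearing call plays.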

fabianp commented 1 year ago

Maybe, but at this point it seems more of an issue concerning jax than jaxopt, wdyt?

On Fri, May 19, 2023, 11:30 Mathieu Blondel @.***> wrote:

It's nice to have a workaround but shouldn't garbage collection be able to do this automatically?


mblondel commented 1 year ago

Agreed!

fabianp commented 1 year ago

@froystig made a good point in a private conversation: this might be symptomatic of jaxopt not using the cache properly and/or generating too many fresh functions instead of re-using the cache.

I don't have the bandwidth to look into it right now, but I'm leaving this open in case someone can look into it more deeply.
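
The "fresh functions" point can be illustrated in plain JAX. This is a minimal sketch, not jaxopt's actual code path, and the thread does not establish that jaxopt behaves this way. It compares jitting a newly created closure on every call with jitting one shared function that takes the varying parameter as an argument. All function names are illustrative.

```python
import jax
import jax.numpy as jnp

trace_log = []


def solve_with_fresh_closure(target):
    # A new function object is created on every call, so jax.jit cannot
    # match it to anything it has traced before.
    def obj(x):
        trace_log.append("fresh")  # runs only while JAX traces `obj`
        return jnp.square(x - target).sum()

    return jax.jit(obj)(jnp.zeros(1))


def _shared_obj(x, target):
    trace_log.append("shared")  # runs only while JAX traces `_shared_obj`
    return jnp.square(x - target).sum()


shared_obj = jax.jit(_shared_obj)  # one function object, jitted once


def solve_with_shared_function(target):
    # The varying parameter is passed as an array argument, not closed over.
    return shared_obj(jnp.zeros(1), jnp.array(float(target)))


for i in range(5):
    solve_with_fresh_closure(float(i))
    solve_with_shared_function(i)

fresh_traces = trace_log.count("fresh")
shared_traces = trace_log.count("shared")
print(fresh_traces, shared_traces)
```

The fresh-closure variant is traced again on every call, while the shared function is traced once and then served from the cache. In the examples in this thread, `obj` is defined inside `optimize` and a new solver is built on each call, which is the pattern the fresh-closure variant mimics.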