google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io

Unnecessary recompilation of _while_loop_lax #563

Open · hrdl-github opened this issue 6 months ago

hrdl-github commented 6 months ago

ed1febef279de5e1120af70dff04261b61175383 claims that jitting _while_loop_lax is redundant. However, the change also seems to prevent the resulting loop from being cached, causing recompilation of loops such as https://github.com/google/jaxopt/blob/58bac0ac375732dce358cf85583ae7fe3632b8cf/jaxopt/_src/base.py#L314. Reverting ed1febef279de5e1120af70dff04261b61175383 drastically reduces compilation times for my use case, so it probably makes sense to address this caching issue.

mblondel commented 6 months ago

CC @froystig

hrdl-github commented 6 months ago

This is mostly relevant when reusing the same solver (I've been using BFGS): re-instantiating a solver creates new references to _cond_fun and _body_fun, which are static arguments in https://github.com/google/jaxopt/blob/58bac0ac375732dce358cf85583ae7fe3632b8cf/jaxopt/_src/loop.py#L80, I think.
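
A minimal sketch of the mechanism (illustrative only, not jaxopt's actual code): jax.jit keys its cache on the hash of static arguments, and plain Python functions hash by identity, so freshly created closures always miss the cache:

import jax
from functools import partial

@partial(jax.jit, static_argnums=(0, 1))
def while_loop(cond_fun, body_fun, init_val):
    return jax.lax.while_loop(cond_fun, body_fun, init_val)

def make_funs():
    # Fresh closures each call, as when a solver is re-instantiated.
    return (lambda v: v < 10.0), (lambda v: v + 1.0)

cond, body = make_funs()
while_loop(cond, body, 0.0)    # compiles
while_loop(cond, body, 0.0)    # cache hit: same function objects
cond2, body2 = make_funs()
while_loop(cond2, body2, 0.0)  # recompiles: new objects, new static-arg key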

froystig commented 6 months ago

Do you have a minimal example that reproduces the slowdown, and would you mind posting it here if so?

hrdl-github commented 6 months ago

import time, jaxopt, jax, jax.numpy as jnp

def rosenbrock(x):
    return jnp.sum(100. * jnp.diff(x) ** 2 + (1. - x[:-1]) ** 2)

solver = jaxopt.BFGS(rosenbrock)
x0 = jnp.zeros(2)

_time = time.time()
sol = solver.run(x0)
_time = time.time() - _time
print(f'Total {_time} s')

jax.config.update('jax_log_compiles', True)

_time = time.time()
sol2 = solver.run(x0)
_time = time.time() - _time
print(sol2.params)  # the solution; appears as [1. 1.] in the outputs below
print(f'Total2 {_time} s')

With ed1febef279de5e1120af70dff04261b61175383 reverted (_while_loop_lax jitted):

Total 1.3411040306091309 s
[1. 1.]
Total2 0.007392406463623047 s

Original library (jax 0.4.23, jaxopt 0.8.2):

Total 1.3537604808807373 s
Finished tracing + transforming while for pjit in 0.0003256797790527344 sec
Compiling while for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(float32[2]), ShapedArray(int32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[2]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[2,2]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
Finished jaxpr to MLIR module conversion jit(while) in 0.1349470615386963 sec
Finished XLA compilation of jit(while) in 0.3022458553314209 sec
[1. 1.]
Total2 0.45375514030456543 s

hrdl-github commented 6 months ago

More generally, what do you think about jaxopt caching all solvers, so that recompilation is avoided automatically as long as solvers are not created inside nested functions?

mblondel commented 6 months ago

What do you propose, more concretely?

hrdl-github commented 6 months ago

At the moment I essentially use jaxopt with https://github.com/google/jaxopt/commit/ed1febef279de5e1120af70dff04261b61175383 reverted and make sure that I cache my solvers:

from functools import lru_cache

@lru_cache
def get_solver(solver, *args, **kwargs):
    # Cache instances so repeated calls return the same solver object.
    return solver(*args, **kwargs)
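
For illustration, reuse then looks like this (using the rosenbrock setup from the reproduction above):

solver = get_solver(jaxopt.BFGS, rosenbrock)       # first call instantiates
same_solver = get_solver(jaxopt.BFGS, rosenbrock)  # lru_cache hit
assert solver is same_solver  # same instance, so its jitted internals are reused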

My suggestion would be to

  1. Implement caching of _while_loop_lax cleanly without relying on jax.jit -- I haven't dug deep enough into jax yet to know the best way to do this (a rough sketch follows below), and
  2. Make it easy for users to reuse solvers, or at least document that reusing solvers reduces or avoids recompilation, so that they benefit from this change. Another topic would be advising against nested functions, which I've seen in a lot of non-official examples.
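
One possible shape for (1), as a rough sketch only (make_while_loop is a made-up name, it still uses jax.jit internally, only with an explicit cache, and it is untested against jaxopt's internals): key a cache on the cond/body function objects, so that reusing the same functions reuses the same compiled loop.

from functools import lru_cache
import jax

@lru_cache(maxsize=None)
def make_while_loop(cond_fun, body_fun):
    # The cache is keyed on the (cond_fun, body_fun) objects, so the jitted
    # loop below is built once per pair and reused on later requests.
    @jax.jit
    def loop(init_val):
        return jax.lax.while_loop(cond_fun, body_fun, init_val)
    return loop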

NeilGirdhar commented 6 months ago

Would it be possible to detect that you're inside a jitted context, rather than accepting the boolean jit parameter? That way, only the user would ever call jit, and they would fully control caching and compilation.
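
One heuristic in that direction, for illustration (a sketch; inside_jit is a made-up helper, and it treats any tracing transformation, not just jit, as "jitted"): under jax.jit, arguments arrive as Tracer objects.

import jax
import jax.numpy as jnp

def inside_jit(x):
    # Under jax.jit (and other JAX transformations), arguments are Tracers.
    return isinstance(x, jax.core.Tracer)

def g(x):
    print('traced' if inside_jit(x) else 'eager')  # runs at trace time
    return x + 1.0

g(jnp.float32(0.0))           # prints 'eager'
jax.jit(g)(jnp.float32(0.0))  # prints 'traced' while tracing

A caveat: this conflates jit with other tracing transformations such as grad and vmap, so it may be too coarse a signal for deciding whether to add a jit wrapper.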