Closed Joshuaalbert closed 3 weeks ago
Absolutely! So what's going on here is that if we have a function that looks something like this:
```python
def nonlinear_solve(fn, ...):
    ... = fn(...)
    ... = fn(...)
    ...
```
then JAX will trace and compile the user-provided fn
twice. This approximately doubles the compilation time! To emphasize: this is about the speed of compiling, not the speed of runtime.
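As a toy illustration (hypothetical code, not from any library), we can count how many times `fn`'s Python body is run while JAX traces the outer function -- once per call site:

```python
import jax

trace_count = 0

def fn(x):
    global trace_count
    trace_count += 1  # incremented each time fn's Python body is traced
    return x ** 2 - 2.0

@jax.jit
def naive_solve(x):
    # Two textual call sites => fn's body is traced twice, and its
    # operations appear twice in the program handed to XLA.
    y = fn(x)
    z = fn(y)
    return z

naive_solve(1.0)
print(trace_count)  # 2: one trace per call site
```

Each of those traces becomes its own set of operations in the compiled program, so the compiler does the work for `fn` twice.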
The thing is that a lot of scientific programs can get very large indeed. Imagine optimizing the parameters of a stiff two-point boundary value problem: this is done by performing
(a) gradient descent over a
(b) root-finding problem, whose fn
is a
(c) differential equation solve, which on each step
(d) involves finding the solution to say 6 more root-finding problems!
Let's imagine that each root-finding solver did something like the above, and had two compilations where only one would suffice. Then in this example, we'd end up with compilation taking 2 × 6 = 12 times as long (a factor of 2 from (b) and 6 from (d)). Indeed, I've personally seen examples where fixing this took compilation from around an hour down to about five minutes. (!)
Optimistix (and all of its sibling libraries, e.g. Diffrax for differential equation solving) take care to ensure that they do a good job of this issue.
Okay, as for a concrete example of where we do this kind of thing, let's look at an example of gradient descent with a line search. Algorithms of this type need to store two function evaluations at a time: one at the start of the line search, and one at however far along the line search they currently are. They tend to iterate in an inner loop along the line search, and then, when they are happy, they accept that point, return to an outer loop, and start a new line search. In addition, they also need a gradient evaluation at the start of each line search (to set the direction).
Let's look at where we do all of that:
Here, we have flattened these two loops together. (The loop-over-line-searches and the loop-within-a-line-search.) This means that we do not need to write `... = fn(...)` in two different places, and as such can immediately halve our compilation time!
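To sketch the idea (a toy flattened loop, not Optimistix's actual implementation -- the acceptance rule and step-size updates here are made up): the outer loop and the line-search loop share a single `lax.fori_loop`, so `fn` appears at exactly one call site.

```python
import jax.numpy as jnp
from jax import lax

def flattened_descent(fn, x0, n_steps=50):
    def body(_, state):
        x, f_x, step = state
        trial = x - step                  # toy fixed search direction
        f_trial = fn(trial)               # the ONLY call site of fn
        accept = f_trial < f_x            # toy acceptance criterion
        # On acceptance: move to the trial point and grow the step (start a
        # "new line search"); on rejection: stay put and shrink the step (a
        # step "within the line search"). Either way, no second fn call.
        x = jnp.where(accept, trial, x)
        f_x = jnp.where(accept, f_trial, f_x)
        step = jnp.where(accept, step * 1.5, step * 0.5)
        return x, f_x, step

    init = (jnp.asarray(x0), jnp.asarray(jnp.inf), jnp.asarray(0.1))
    x, f_x, _ = lax.fori_loop(0, n_steps, body, init)
    return x, f_x

# Toy objective with its minimum at x = -3:
x, f_x = flattened_descent(lambda x: (x + 3.0) ** 2, 0.0)
```

The flag-style state (`accept`) is what lets one loop body play both roles.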
Moreover notice that little lin_to_grad
hiding out in there. We only sometimes need a gradient (when setting the direction for a new line search). Now if we wrote two separate branches:
```python
lax.cond(need_gradient, lambda ...: fn(...), lambda ...: jax.value_and_grad(fn)(...))
```
then we would also be approximately doubling compilation time, because both fn
and jax.value_and_grad(fn)
would need to be compiled. (And obviously, if we just always evaluated jax.value_and_grad(fn), then we'd waste a lot of runtime computation by evaluating the gradient when we often don't need it.)
Our solution is for every function evaluation to actually linearize the function -- which both evaluates it and keeps around the extra intermediate values needed for later autodifferentiation -- and then only if required will we perform the second half of the autodifferentiation (transposing the tangent function, and propagating cotangents). In short, we've chopped backpropagation into two different pieces to ensure that we get optimal runtime and optimal compiletime performance.
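Here's what that decomposition looks like in raw JAX (a sketch of the general technique, not Optimistix's actual code): `jax.linearize` gives the value plus the tangent map, and `jax.linear_transpose` performs the second half of reverse mode only when the gradient is actually wanted.

```python
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.sum(jnp.sin(x) ** 2)   # toy scalar objective

x = jnp.arange(3.0)

# First half: evaluate fn AND keep around the intermediate values needed
# for autodiff, at a single trace of fn's body.
value, jvp_fn = jax.linearize(fn, x)

# Second half, only when a gradient is actually needed: transpose the
# linear tangent map and propagate a cotangent back through it.
(grad,) = jax.linear_transpose(jvp_fn, x)(jnp.ones_like(value))
```

This matches ordinary reverse-mode autodiff (`jax.grad(fn)(x)`), but splits it into two pieces that can be invoked at different points in the solver loop.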
As a final example of this kind of compilation stuff -- as part of the initial set-up for the solver, we need to (a) run through a function to check if it has any closed-over variables, and (b) set up some initial fn
evaluations for the evolving state of the solver. We take care to do both of these in the same pass. We do (a) here:
and in doing so we are able to get (b) for free here:
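As a rough sketch of that idea in plain JAX (Optimistix's actual implementation differs; this just shows that one trace can serve both purposes): tracing `fn` once with `jax.make_jaxpr` exposes closed-over arrays as constants, and the resulting jaxpr can then be evaluated without tracing `fn` a second time.

```python
import jax
import jax.numpy as jnp
from jax import core

a = jnp.arange(3.0)                   # a closed-over array
fn = lambda x: jnp.sum(a * x)

x0 = jnp.ones(3)
closed = jax.make_jaxpr(fn)(x0)       # one trace of fn...
# (a) closed-over values surface as the jaxpr's constants:
print(closed.consts)
# (b) ...and the same trace lets us evaluate fn without re-tracing it:
(out,) = core.eval_jaxpr(closed.jaxpr, closed.consts, x0)
```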
And all of that is just to do a simple gradient descent with line search! As you can imagine, more complicated solvers involve more work again :)
IMO, handling this kind of thing is precisely the reason that libraries like Optimistix exist! I think this is complicated stuff -- not to mention all the other stuff that went into Optimistix, like the search/descent abstraction, extensibility, problem reduction, etc. etc. -- and I'm a strong believer that good software is a key enabler of good science.
Ah, I recognise what you're doing there: flattening multi-part iterative code, plus linearising functions whose gradients may only sometimes be needed.
> IMO, handling this kind of thing is precisely the reason that libraries like Optimistix exist!
It does take a lot of effort to write such good code. However, there are diminishing returns to that effort. I think if the library is limited to a smallish set of commonly used algorithms, then it's logical to have such consistency. I'm interested in getting much more of the science community involved in sharing a diverse set of algorithms, in a more decentralised manner. My concern is that there will be considerably fewer contributions if a form factor is enforced. To manage the potential chaos, I think a process of peer review and standardised regression benchmarks is essential. I also think there is a sweet spot: enabling Optimistix-like rigour by gently suggesting the usage of well-maintained tooling.
> I'm interested in getting much more of the science community involved in sharing a diverse set of algorithms
Right! So now you're speaking to how we should handle extensibility. How to make it possible for someone to come up with a whole-new optimization algorithm and then for everyone to make use of that.
We've thought about this too :)
For Optimistix this is handled through the use of abstract base classes ('ABC'). The interface to a solver is defined by classes like optimistix.AbstractMinimiser
. Now optimistix.minimise
only requires that your solver be an instance of this class, and moreover the example I linked above is just an implementation of this ABC.
What this means, for example, is that third party A could create their own GitHub repository, which is a Python package with a class that implements this ABC:
```python
# third-party-A/optimistix-extensions/cool_solver.py
from optimistix import AbstractMinimiser

class CoolSolver(AbstractMinimiser): ...
```
And third party B could then write something that uses this in their own code, flexibly and interchangeably with any other Optimistix routine:
```python
# third-party-B/some-project/foo.py
from optimistix import minimise, GradientDescent
from optimistix_extensions import CoolSolver

if foo:
    solver = GradientDescent(...)
else:
    solver = CoolSolver(...)

minimise(solver, ...)
```
The decoupling here speaks to your 'distributed' point.
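The pattern itself is just ordinary polymorphism through an abstract base class. A self-contained toy version (hypothetical names and interface -- Optimistix's real `AbstractMinimiser` has a richer API than this):

```python
import abc
import jax
import jax.numpy as jnp

class AbstractMinimiser(abc.ABC):
    """Toy stand-in for optimistix.AbstractMinimiser (interface simplified)."""
    @abc.abstractmethod
    def step(self, fn, x):
        ...

def minimise(solver: AbstractMinimiser, fn, x, n_steps=200):
    # minimise only depends on the ABC, so any third-party solver plugs in
    # interchangeably with the built-in ones.
    for _ in range(n_steps):
        x = solver.step(fn, x)
    return x

class CoolSolver(AbstractMinimiser):
    """A 'third-party' solver: plain gradient descent with a fixed step."""
    def __init__(self, lr):
        self.lr = lr

    def step(self, fn, x):
        return x - self.lr * jax.grad(fn)(x)

x = minimise(CoolSolver(lr=0.1), lambda x: jnp.sum((x - 2.0) ** 2), jnp.zeros(3))
```

The point is that `minimise` never names a concrete solver class, which is exactly what lets the ecosystem grow in a decentralised way.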
> To manage the potential chaos I think a process of peer review and standardised regression benchmarks is essential. I also think there is a sweet spot of enabling optimistix-like rigour, by gently suggesting the usage of well-maintained tooling.
So notably, what the above design ensures is that algorithms can be implemented and tried out without requiring the high standards of rigour I described in my previous message, and they can still be consumed by everyone.
And then if such an algorithm becomes useful enough and well-implemented enough, it can be contributed to some standard place (potentially Optimistix itself).
That would only make sense if the algo being contributed had a corresponding ABC :) But we've discussed this in our meeting already. I'll close this ticket.
In this comment, optimistix is mentioned as improving performance by ensuring the objective function is compiled only once. Can you help shed light on what that exactly means and in which situations you're applying a trick, and perhaps point to code where I can take a look?