Closed: LouisDesdoigts closed this issue 10 months ago
Hmm, so my first question is what kind of behaviour you're seeking, specifically?
Basically, what goes wrong with fixing an rtol, which already handles multiple scales?
Haha yeah sorry probably should have just provided an example of the desired functionality.
```python
import jax
import jax.numpy as np
import equinox as eqx
import optimistix as optx
import optax


# Set up model
class Linear(eqx.Module):
    m: jax.Array
    b: jax.Array

    def __init__(self, m, b):
        self.m = m
        self.b = b

    def __call__(self, x):
        return self.m * x + self.b


# Simple loss
def loss_fn(model, args):
    x, y = args
    return np.mean((model(x) - y) ** 2)


# Normal optax optimiser with per-leaf learning rates
linear = Linear(np.array(1.0), np.array(0.0))
param_spec = eqx.tree_at(lambda x: (x.m, x.b), linear, ("m", "b"))
optim = optax.multi_transform({"m": optax.adam(1e-3), "b": optax.adam(1e3)}, param_spec)

# Per-leaf atol and rtol
rtol = eqx.tree_at(lambda x: (x.m, x.b), linear, (0.1, 10))
atol = eqx.tree_at(lambda x: (x.m, x.b), linear, (0.1, 10))

# Optimistix minimiser (desired API, not currently supported)
solver = optx.OptaxMinimiser(
    optim,
    rtol=(1e-3, rtol),  # f-space rtol, y-space rtol
    atol=(1e-3, atol),  # f-space atol, y-space atol
)
```
So in this example my termination condition would have a different `rtol` and `atol` for the loss (f-space) and for each leaf (y-space), as opposed to having the same termination value applied to everything. Does that clarify my question?
It's also possible I have misunderstood something about how the termination condition works, so please let me know if that's the case!
Sure! Sorry, to be clear, I understand the ask, and the fact that Optimistix doesn't support this right now. What I'm trying to better understand (before thinking about a possible solution) is why this kind of mixed-tolerance is a desirable thing to want in the first place.
Typically the reason for having `rtol` is so that you can be sure of getting solutions whose accuracy scales linearly with the scale of the problem. (And likewise `atol` exists to get scale-invariant accuracies.)
So if that isn't sufficient, is it because you want some nonlinear function scale->accuracy? What kind of nonlinear function / why?
Ah okay, yeah, let me explain my reasoning, as I'm not well versed on the theory behind all this stuff, so I might just not understand the use of the `rtol` and `atol` values correctly.
We are optimising forward models with great diversity in parameter scales, on problems with large diversity in likelihoods. Taking a two-parameter example, we might be trying to find both the position and brightness of a star imaged through a telescope. On-sky position is typically measured in arcseconds of order 1e-3, and brightness is typically measured in photons, which can range in value from 1e4 to 1e12. Having a single termination value for both of these parameters is difficult.
The on-sky position measurement is relative to the optics, so if our true values can be arbitrarily close to zero, the desired `rtol` would need to be somewhat large (or possibly even ignored), and the `atol` would want to be ~1e-3.
The brightness of the star, however, could cover many orders of magnitude, so there isn't really a concrete `atol` value that makes sense (i.e. if we take 1e-3 to match the position, that would be far too small). In this case we would need an `rtol` of ~1e-3.
That is essentially the core of my issue, and maybe this should be framed more as a question: how would you go about devising convergence criteria for a problem like this? Am I thinking about convergence incorrectly in the first place?
From what you've said, I suspect taking `atol=1e-3` and `rtol=1e-3`, or thereabouts, should be about right. The overall scale is given by `atol + rtol * value`, so the `rtol` will be negligible for the on-sky position and the `atol` will be negligible for the brightness.
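To make that concrete, here is a small illustrative sketch of how the combined scale `atol + rtol * value` behaves for the two parameters from the example above (the particular values for position and brightness are assumptions taken from the discussion, not anything Optimistix computes):

```python
# Illustrative arithmetic only: how the combined tolerance scale
# atol + rtol * |value| behaves for two very differently scaled parameters.
atol, rtol = 1e-3, 1e-3

def scale(value):
    return atol + rtol * abs(value)

position = 1e-3    # on-sky position in arcseconds (assumed example value)
brightness = 1e8   # star brightness in photons (assumed example value)

print(scale(position))    # ~1e-3: the atol term dominates
print(scale(brightness))  # ~1e5:  the rtol * value term dominates
```

So a single `(atol, rtol)` pair already yields an absolute-style tolerance for the small parameter and a relative-style tolerance for the large one.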
For what it's worth, if we were to change this, then I'd be tempted to do it through the `norm` instead -- perhaps introduce a separate norm for the `y` and `f` spaces. Then you can scale each component however might be desired.
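As a rough sketch of that idea (this is a hypothetical helper, not an existing Optimistix API; the `weights` pytree and `make_weighted_norm` name are assumptions), a norm that rescales each leaf before taking a max-norm might look like:

```python
import jax.numpy as jnp
from jax import tree_util as jtu

# Hypothetical sketch: a max-norm that divides each leaf by a per-leaf
# weight before taking the maximum. `weights` is an assumed pytree with
# the same structure as the trees the norm will be applied to.
def make_weighted_norm(weights):
    def weighted_norm(tree):
        # Per-leaf max-abs, rescaled by that leaf's weight
        scaled = jtu.tree_map(lambda x, w: jnp.max(jnp.abs(x)) / w, tree, weights)
        # Overall norm is the max over all rescaled leaves
        return jnp.max(jnp.array(jtu.tree_leaves(scaled)))
    return weighted_norm
```

Something shaped like this could then, in principle, be passed wherever a solver accepts a `norm` function, letting each component be scaled independently.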
Yeah, that was just a small example -- in practice we are optimizing over a dozen unique sets of parameters, so trying to find a balance between every leaf type becomes unwieldy.
I think I'm starting to get my head around the way this works under the hood. Building a robust `norm` function seems like it could solve this problem, and possibly also be used as a way to normalize parameter values through a custom solver.
Ultimately all of these questions are also in the context of #20, where parameter scales are just a problem in general so a robust solution would also need to be cognizant of that too. I'll have to look into this more when I have some more time to consider both the problem and solutions more carefully.
Thanks for the info!
You're welcome! Let us know how it goes -- we can definitely add something to the API if your problem ends up being tricky to implement as-is.
So I would love to be able to pass in a pytree for both the `rtol` and `atol` values, in a similar vein to how you can set individual learning rates for each leaf in `optax`. This would make a lot of sense for most of my work, which has pytree leaves with vastly different parameter scales.

Looking at the termination condition code, it looks like this hasn't been made an option because the values are applied to both the pytree leaf values ('y space') and the loss value ('f space').
From what I can tell there would be two ways to get this behavior:
1. Allow a custom termination condition. I don't think this is the best solution, as the Cauchy termination is already wrapped up with the input `norm` function/pytree.
2. Allow pytree inputs for `rtol` and `atol`. I think this could be done relatively easily by allowing the f-space and y-space conditions to be individually specified via a tuple like this: `(f-space rtol (float), y-space rtol (float or pytree))`. This would preserve the present syntax, while also allowing users full freedom over the termination condition.
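To sketch what the y-space half of option 2 might look like (this is a hypothetical illustration of a per-leaf Cauchy-style check, not Optimistix's actual termination code; `tree_converged` and its arguments are assumed names):

```python
import jax.numpy as jnp
from jax import tree_util as jtu

# Hypothetical sketch: a Cauchy-style convergence check where `rtol` and
# `atol` are pytrees matching the structure of `y`, so each leaf gets its
# own tolerances. `y_diff` is the change in `y` over the last step.
def tree_converged(y_diff, y, rtol, atol):
    per_leaf_ok = jtu.tree_map(
        lambda d, yv, r, a: jnp.all(jnp.abs(d) <= a + r * jnp.abs(yv)),
        y_diff, y, rtol, atol,
    )
    # Converged only when every leaf satisfies its own tolerance
    return jnp.all(jnp.array(jtu.tree_leaves(per_leaf_ok)))
```

The f-space half would remain a scalar comparison on the loss, matching the `(float, pytree)` tuple shape proposed above.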
Anyway, maybe there is a better way that I'm missing, but having this functionality is actually somewhat essential for using optimistix in my work in the long run, so let me know your thoughts!