Closed by jgallowa07 5 months ago
The current approach doesn't handle line search consistently, because each call to `run` re-initializes the step size. I wouldn't call this nondeterministic, but it is a deterministic function of how we partition iterations into calls to `run`. It would be preferable for our API to interface with the `update` method of `ProximalGradient`, rather than the `run` method, so we can record loss trajectories in a consistent way. We could define our own `run` command that iterates calls to `update` until `state.error <= tol` or `state.iter_num == maxiter`, and outputs loss trajectory data from each iterate.
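A minimal sketch of what such a custom `run` loop could look like. This uses NumPy and a hand-rolled proximal gradient step (soft-thresholding for a lasso penalty) rather than jaxopt itself, so the function and parameter names (`run_with_trajectory`, `soft_threshold`) are illustrative, not part of any existing API:

```python
import numpy as np

def soft_threshold(x, thresh):
    # Proximal operator of the l1 penalty.
    return np.sign(x) * np.maximum(np.abs(x) - thresh, 0.0)

def run_with_trajectory(A, b, lam, stepsize=None, tol=1e-6, maxiter=1000):
    """One outer loop over single update steps, recording loss per iterate.

    Mirrors the proposal: iterate updates until error <= tol or
    iter_num == maxiter, keeping the loss trajectory along the way.
    """
    if stepsize is None:
        # 1 / Lipschitz constant of the smooth part of the objective.
        stepsize = 1.0 / np.linalg.norm(A, 2) ** 2
    x = np.zeros(A.shape[1])
    losses = []
    for _ in range(maxiter):
        grad = A.T @ (A @ x - b)
        x_new = soft_threshold(x - stepsize * grad, stepsize * lam)
        # Fixed-point residual as the convergence error.
        error = np.linalg.norm(x_new - x) / stepsize
        x = x_new
        losses.append(0.5 * np.sum((A @ x - b) ** 2) + lam * np.sum(np.abs(x)))
        if error <= tol:
            break
    return x, np.array(losses)
```

Because the step size is initialized once outside the loop, partitioning the iterations differently cannot change the trajectory.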
> I wouldn't call this nondeterministic, but it is a deterministic function of how we partition iterations into calls to `run`

This is great - yes, that makes sense to me. Luckily we're actually deterministic otherwise.

> We could define our own `run` command that iterates calls to `update` until `state.error <= tol` or `state.iter_num == maxiter`, and outputs loss trajectory data from each iterate.

This is amazing. It'll be until after resubmission that I'll be able to do this - I'd happily assist if you have the bandwidth to PR.
I'll just note that this update will probably help with some issues I seem to be running into with the spike analysis. I suspect this has to do with the tolerance, learning rate, memory, and other things that may be easier to control with a custom `update()` loop - but it seems the model fits are not so robust after many training iterations. For context, we run the spike models for 30 independent rounds of 1000 iterations each (30K total), with the default tolerance of 1e-4. However, when we run those same models for a single round of 100K iterations (everything else the same), we can see that some of the models have certainly overfit to their data. I wonder how exactly the tolerance interacts with penalties - could it be that penalties added to the total cost are affecting the model's ability to quit early? This could also be related to #133, as a ridge penalty does again seem to stabilize things. More testing will be needed here.
Relevant to this issue: the primary difference between single-step and multi-step model optimization is the FISTA acceleration. In the single-step models, the learning rate is reset at each step. When acceleration is turned off, the two approaches yield identical results.
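A small self-contained demonstration of this point, using a hand-rolled FISTA/ISTA on a lasso problem (the function names and the `chunked` helper are hypothetical, not from jaxopt). Without acceleration, the update has no momentum state, so splitting N iterations across multiple calls is transparent; with acceleration, the momentum scalar `t` and extrapolation point are reset at every call, and the results diverge:

```python
import numpy as np

def prox_grad(A, b, lam, x0, n_iter, acceleration=True):
    # FISTA when acceleration=True, plain ISTA otherwise.
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    x = x0.copy()
    y = x0.copy()   # extrapolation point (momentum state)
    t = 1.0         # FISTA momentum scalar, reset on every call
    for _ in range(n_iter):
        grad = A.T @ (A @ y - b)
        z = y - step * grad
        x_new = np.sign(z) * np.maximum(np.abs(z) - step * lam, 0.0)
        if acceleration:
            t_new = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
            y = x_new + ((t - 1.0) / t_new) * (x_new - x)
            t = t_new
        else:
            y = x_new
        x = x_new
    return x

rng = np.random.default_rng(0)
A = rng.normal(size=(30, 10))
b = rng.normal(size=30)
x0 = np.zeros(10)

def chunked(n_total, n_chunks, accel):
    # Partition n_total iterations into n_chunks separate calls,
    # re-initializing the internal momentum state at each call.
    x = x0
    for _ in range(n_chunks):
        x = prox_grad(A, b, 0.1, x, n_total // n_chunks, acceleration=accel)
    return x
```

Comparing `chunked(20, 1, accel=False)` with `chunked(20, 4, accel=False)` gives bitwise-identical results, while the accelerated versions differ.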
We still need to add some sort of check on convergence. To do this, we'll add a `state` property to the `Model` object. The state, from `jaxopt`, gives the following properties:

```python
from typing import Any, NamedTuple, Optional

class ProxGradState(NamedTuple):
    """Named tuple containing state information."""
    iter_num: int
    stepsize: float
    error: float
    aux: Optional[Any] = None
    velocity: Optional[Any] = None
    t: float = 1.0
```
I think what we want to do is simply check whether `iter_num` is less than the maximum iterations requested. This will tell us if the model exited upon meeting the specified tolerance threshold, rather than exhausting its iteration budget.
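Sketched as a helper (re-declaring the state tuple so the snippet is self-contained; `met_tolerance` is a hypothetical name, not jaxopt API):

```python
from typing import Any, NamedTuple, Optional

class ProxGradState(NamedTuple):
    iter_num: int
    stepsize: float
    error: float
    aux: Optional[Any] = None
    velocity: Optional[Any] = None
    t: float = 1.0

def met_tolerance(state: ProxGradState, maxiter: int) -> bool:
    # If the solver stopped before exhausting its iteration budget,
    # the stopping criterion (error <= tol) is what terminated it.
    return state.iter_num < maxiter
```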
This issue is to note that `jaxopt`, the package we use for optimizing our model, is merging into `optax`. While this poses no problem for the software as it stands, it would certainly be desirable to make the switch once `ProximalGradient` is merged in. That would solve the current problems with the way we run training steps just to get a convergence line, which leads to non-deterministic results (i.e. 1000 iterations for one step != 100 iterations for 10 steps), and would help future development more generally.