matsengrp / multidms

Joint modeling of multiple deep mutational scanning experiments
https://matsengrp.github.io/multidms
MIT License

Convergence criteria #137

Closed jgallowa07 closed 5 months ago

jgallowa07 commented 6 months ago

This issue is to note that jaxopt, the package we use for optimizing our model, is merging into optax. While this poses no problem for the software as it stands, it would certainly be desirable to make the switch once ProximalGradient is merged in.

This would solve current problems with the way we run training in separate steps just to get a convergence line. That approach leads to step-dependent results (i.e. 1000 iterations in one step != 100 iterations over 10 steps). It would also help future development more generally.

wsdewitt commented 6 months ago

The current approach doesn't handle line search in a consistent way, because each call to run re-initializes the step size. I wouldn't call this nondeterministic, but it is a deterministic function of how we partition iterations into calls to run. It would be preferable for our API to interface with the update method of ProximalGradient, rather than the run method, so we can record loss trajectories in a consistent way. We could define our own run method that iterates calls to update until state.error <= tol or state.iter_num == maxiter, and outputs loss trajectory data from each iterate.
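A minimal sketch of that pattern in plain NumPy (this is not the jaxopt ProximalGradient API, and the function name and problem here are illustrative only): iterate single proximal-gradient updates, record the loss at each iterate, and stop once the step-to-step change falls below tol or maxiter is reached.

```python
import numpy as np

def run_with_trajectory(X, y, lam, stepsize, tol=1e-6, maxiter=1000):
    """Toy lasso solver illustrating the update-until-converged loop:
    one update per iteration, loss recorded at every iterate, early
    exit when the change between iterates drops below tol."""
    beta = np.zeros(X.shape[1])
    losses = []
    for _ in range(maxiter):
        # gradient step on the smooth least-squares term
        grad = X.T @ (X @ beta - y) / len(y)
        beta_new = beta - stepsize * grad
        # soft-threshold prox for the l1 penalty
        beta_new = np.sign(beta_new) * np.maximum(
            np.abs(beta_new) - stepsize * lam, 0.0
        )
        error = np.linalg.norm(beta_new - beta)
        beta = beta_new
        losses.append(
            0.5 * np.mean((X @ beta - y) ** 2) + lam * np.abs(beta).sum()
        )
        if error < tol:
            break
    return beta, losses
```

Because the loop owns every iterate, the loss trajectory is identical no matter how callers would otherwise have partitioned the iterations.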

jgallowa07 commented 6 months ago

> I wouldn't call this nondeterministic, but it is a deterministic function of how we partition iterations into calls to run

This is great; yes, that makes sense to me. Luckily, we're actually deterministic otherwise.

> We could define our own run method that iterates calls to update until state.error <= tol or state.iter_num == maxiter, and outputs loss trajectory data from each iterate.

This is amazing. It'll be until after resubmission that I'll be able to do this; I'd happily assist if you have the bandwidth to open a PR.

jgallowa07 commented 6 months ago

I'll just note that this update will probably help with some issues I seem to be running into with the spike analysis.

I suspect this has to do with the tolerance, learning rate, memory, and other things that may be easier to control with a custom update() loop, but it seems the model fits are not so robust after many training iterations. For context, we run the spike models for 30 independent rounds of 1000 iterations each (30K total). These have the default tolerance set to 1e-4.

[Screenshot from 2024-03-05 06-44-15]

However, when we run those same models for a single round of 100K iterations (everything else the same),

[Screenshot from 2024-03-05 06-57-43]

we can see that some of the models have certainly over-fit their data. I wonder how exactly the tolerance interacts with penalties? Could it be that penalties added to the total cost are affecting the model's ability to quit early?

This could also be related to #133, as a ridge penalty does again seem to stabilize things. More testing will be needed here.

jgallowa07 commented 5 months ago

Relevant to this issue, it should be noted that the primary difference between single-step and multi-step model optimization is the FISTA acceleration. In the multi-step models, the learning rate is reset at each step. When acceleration is turned off, the two approaches yield identical results.
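A toy illustration of why that happens (plain NumPy, not multidms or jaxopt; the function and problem below are made up for the example). FISTA carries momentum state t across iterations, so restarting it at 1 each round makes one long run differ from several short ones, while plain proximal gradient is a memoryless map and is unaffected by how iterations are partitioned.

```python
import numpy as np

def prox_grad(beta, y, lam, step, n_iter, accel):
    """One 'round' of (optionally FISTA-accelerated) proximal gradient
    on 0.5*||beta - y||^2 + lam*||beta||_1. The momentum scalar t
    restarts at 1.0 every call, mimicking per-round resets."""
    t, z = 1.0, beta.copy()
    for _ in range(n_iter):
        g = z - y  # gradient of the quadratic term at the extrapolated point
        b_new = np.sign(z - step * g) * np.maximum(
            np.abs(z - step * g) - step * lam, 0.0
        )
        if accel:
            # standard FISTA momentum update and extrapolation
            t_new = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
            z = b_new + ((t - 1.0) / t_new) * (b_new - beta)
            t = t_new
        else:
            z = b_new
        beta = b_new
    return beta
```

With accel=False, running 4 iterations in one call gives exactly the same result as two calls of 2; with accel=True the results differ, because t restarts between calls.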

jgallowa07 commented 5 months ago

We still need to add some sort of check on convergence. To do this, we'll add a state property to the Model object. The state, from jaxopt, has the following fields:

```python
class ProxGradState(NamedTuple):
  """Named tuple containing state information."""
  iter_num: int
  stepsize: float
  error: float
  aux: Optional[Any] = None
  velocity: Optional[Any] = None
  t: float = 1.0
```

I think what we want to do is simply check whether iter_num is less than the requested maximum iterations. This will tell us whether the stopping condition was met and the model exited upon reaching the specified tolerance threshold.
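That check could look like the following sketch (field names taken from the ProxGradState shown above; the converged helper itself is hypothetical and not part of multidms):

```python
from typing import Any, NamedTuple, Optional

class ProxGradState(NamedTuple):
    """Mirror of the jaxopt state tuple shown above."""
    iter_num: int
    stepsize: float
    error: float
    aux: Optional[Any] = None
    velocity: Optional[Any] = None
    t: float = 1.0

def converged(state: ProxGradState, maxiter: int, tol: float) -> bool:
    """True if the solver stopped before exhausting maxiter (i.e. the
    tolerance criterion triggered the exit). Also checks state.error
    directly, in case the final iterate happened to satisfy tol."""
    return state.iter_num < maxiter or state.error <= tol
```

A Model.state property exposing this tuple would let users call such a check after fitting to confirm the optimizer actually converged rather than simply running out of iterations.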