TL;DR
Fitting η(t) and μ(t) is now very fast: a few seconds on a time grid with hundreds of epochs, whereas previously we waited several minutes on grids with only tens of epochs. 🏎
Here are results using data simulated with the Tennessen η(t) in stdpopsim and a TCC-like pulse affecting a few mutation types in μ(t). I'm especially impressed that we recover the out-of-Africa bottleneck so well, considering that our smoothness regularization disfavors such sharp features. 😮
Details about several handy upgrades
Block coordinate descent for η(t) and μ(t)
The function kSFS.coord_desc() performs one iteration of block coordinate descent, with one block optimizing self.η and a second block optimizing self.μ. The function signature and docstring are shown in the box below. TV and spline regularization parameters are available for both η(t) and μ(t), so we can seek L2-smooth or L1-smooth histories for either. The function returns the cost (regularized loss) and can be used repeatedly in a loop to seek convergence.
https://github.com/harrispopgen/mushi/blob/7a930fcfd99a4bc811be23ebf6e3c743fff65f30/mushi.py#L138-L171
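For orientation, here is a minimal sketch of how such a coordinate descent loop might be driven from user code. The keyword arguments shown (the TV and spline penalty strengths and the convergence tolerance) are illustrative placeholders, not the actual signature — see the linked docstring for the real arguments.

```python
# Hypothetical driver loop for kSFS.coord_desc(); argument names are
# illustrative placeholders (the real signature is in the linked docstring).
prev_cost = float("inf")
for it in range(100):  # cap on coordinate descent iterations
    # one pass: optimize self.η, then self.μ, under TV and spline penalties
    cost = ksfs.coord_desc(alpha_tv=1e2, alpha_spline=1e3,
                           beta_tv=1e1, beta_spline=1e2)
    if prev_cost - cost < 1e-6 * abs(prev_cost):  # relative tolerance reached
        break
    prev_cost = cost
```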
Generic Nesterov accelerated proximal gradient descent utility
Both the η(t) and μ(t) blocks mentioned above call utils.acc_prox_grad_descent(), which takes an arbitrary composite objective defined by a differentiable piece, a non-differentiable piece, and a proximal operator corresponding to the latter. The function signature and docstring are shown in the box below.
https://github.com/harrispopgen/mushi/blob/7a930fcfd99a4bc811be23ebf6e3c743fff65f30/utils.py#L94-L118
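To make the "differentiable piece + non-differentiable piece + prox" structure concrete, here is a standalone sketch of the general technique (FISTA-style accelerated proximal gradient) applied to a toy lasso problem. This is not the signature of utils.acc_prox_grad_descent() itself; all names below are invented for illustration.

```python
import numpy as np

def acc_prox_grad(grad_g, prox_h, x0, step, n_iter=500):
    """Sketch of Nesterov-accelerated proximal gradient descent (FISTA-style).

    grad_g: gradient of the differentiable piece g
    prox_h: proximal operator of the non-differentiable piece h, called as prox_h(v, step)
    """
    x = x0.copy()
    y = x0.copy()
    t = 1.0
    for _ in range(n_iter):
        x_new = prox_h(y - step * grad_g(y), step)   # proximal gradient step
        t_new = (1 + np.sqrt(1 + 4 * t ** 2)) / 2    # Nesterov momentum schedule
        y = x_new + (t - 1) / t_new * (x_new - x)    # extrapolation point
        x, t = x_new, t_new
    return x

# Toy composite objective: g(x) = ½‖Ax − b‖², h(x) = λ‖x‖₁ (prox = soft-threshold)
rng = np.random.default_rng(0)
A = rng.normal(size=(50, 20))
b = rng.normal(size=50)
lam = 0.5
grad_g = lambda x: A.T @ (A @ x - b)
prox_h = lambda v, s: np.sign(v) * np.maximum(np.abs(v) - s * lam, 0.0)
step = 1.0 / np.linalg.norm(A, 2) ** 2               # 1/L with L = ‖A‖₂²
x_hat = acc_prox_grad(grad_g, prox_h, np.zeros(20), step)
```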
Automatic differentiation and just-in-time compilation with JAX
There are no analytically-coded derivatives in this branch. Instead, we use jax.grad. We also use jax.jit to compile for speed. In addition to substantially simplifying the code, automatic differentiation lets us quickly prototype alternative regularization approaches.
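As a toy illustration of the pattern (not code from this branch), jax.grad supplies the gradient of a composite loss and jax.jit compiles the resulting update step:

```python
import jax
import jax.numpy as jnp

# Toy illustration only: differentiate a smooth penalized loss with jax.grad
# and JIT-compile the gradient step with jax.jit.
def loss(z, data, alpha):
    fit = jnp.sum((z - data) ** 2)                 # differentiable data fit
    smooth = alpha * jnp.sum(jnp.diff(z) ** 2)     # quadratic smoothness penalty
    return fit + smooth

grad_loss = jax.grad(loss)                         # gradient w.r.t. the first argument

@jax.jit
def step(z, data, alpha, lr=1e-2):
    return z - lr * grad_loss(z, data, alpha)      # one (unaccelerated) gradient step

data = jnp.linspace(0.0, 1.0, 10)
z = jnp.zeros(10)
for _ in range(100):
    z = step(z, data, 0.1)
```

Swapping in a different penalty only requires editing the loss function; the gradient and compiled update come for free, which is what makes prototyping new regularizers fast.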