harrispopgen / mushi

[mu]tation [s]pectrum [h]istory [i]nference
https://harrispopgen.github.io/mushi/
MIT License

Autodiff2 #33

Closed wsdewitt closed 5 years ago

wsdewitt commented 5 years ago

TL;DR

This branch brings several handy upgrades, detailed below.

Block coordinate descent for η(t) and μ(t)

The function kSFS.coord_desc() performs one iteration of block coordinate descent, with one block optimizing self.η and a second block optimizing self.μ. The function signature and docstring are shown in the box below. TV and spline regularization parameters are available for both η(t) and μ(t), so we can seek L2-smooth or L1-smooth histories for either. The function returns the cost (the regularized loss), so it can be called repeatedly in a loop until convergence.

https://github.com/harrispopgen/mushi/blob/7a930fcfd99a4bc811be23ebf6e3c743fff65f30/mushi.py#L138-L171
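To make the block-coordinate pattern concrete, here is a minimal, self-contained NumPy toy (not mushi's coord_desc, and the quadratic problem is purely illustrative): two parameter blocks are updated in alternation while the other is held fixed, and the loop stops once the cost no longer decreases, which is the same convergence loop one would write around kSFS.coord_desc().

```python
import numpy as np

# Toy illustration of block coordinate descent (not mushi's code):
# minimize f(x, y) = ||A x + C y - b||^2 by alternating exact
# least-squares updates of each block while the other is held fixed,
# tracking the cost until it stops decreasing.
rng = np.random.default_rng(0)
A = rng.normal(size=(30, 4))
C = rng.normal(size=(30, 4))
b = rng.normal(size=30)

def cost(x, y):
    return float(np.sum((A @ x + C @ y - b) ** 2))

x = np.zeros(4)
y = np.zeros(4)
prev = np.inf
for it in range(100):
    # block 1: update x with y fixed
    x, *_ = np.linalg.lstsq(A, b - C @ y, rcond=None)
    # block 2: update y with x fixed
    y, *_ = np.linalg.lstsq(C, b - A @ x, rcond=None)
    c = cost(x, y)
    if prev - c < 1e-10:
        break
    prev = c
print(f"converged after {it + 1} iterations, cost = {c:.6g}")
```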

Generic Nesterov accelerated proximal gradient descent utility

Both the η(t) and μ(t) blocks mentioned above call utils.acc_prox_grad_descent(), which takes an arbitrary objective function defined by a differentiable piece, a non-differentiable piece, and a proximal operator corresponding to the latter. The function signature and docstring are shown in the box below.

https://github.com/harrispopgen/mushi/blob/7a930fcfd99a4bc811be23ebf6e3c743fff65f30/utils.py#L94-L118
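The sketch below shows the same smooth/non-smooth split and Nesterov momentum in a generic FISTA-style routine; the names (acc_prox_grad, grad_g, prox_h) are assumptions for illustration, not the implementation in utils.py. The lasso problem at the end is just a stand-in objective with a soft-threshold proximal operator.

```python
import numpy as np

# Generic sketch of Nesterov-accelerated proximal gradient descent (FISTA):
# a differentiable piece g (supplied via its gradient), a non-differentiable
# piece h (supplied via its proximal operator), and a fixed step size.
def acc_prox_grad(x0, grad_g, prox_h, step, n_iter=200):
    x = x0.copy()
    z = x0.copy()          # extrapolated (momentum) point
    t = 1.0                # Nesterov momentum parameter
    for _ in range(n_iter):
        x_new = prox_h(z - step * grad_g(z), step)
        t_new = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
        z = x_new + ((t - 1) / t_new) * (x_new - x)
        x, t = x_new, t_new
    return x

# Stand-in example: lasso, g(x) = 0.5 ||A x - b||^2, h(x) = lam * ||x||_1
rng = np.random.default_rng(0)
A = rng.normal(size=(50, 10))
b = rng.normal(size=50)
lam = 0.5
grad_g = lambda x: A.T @ (A @ x - b)
prox_h = lambda v, s: np.sign(v) * np.maximum(np.abs(v) - s * lam, 0.0)
step = 1.0 / np.linalg.norm(A, 2) ** 2   # 1 / Lipschitz constant of grad_g
x_hat = acc_prox_grad(np.zeros(10), grad_g, prox_h, step)
```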

Automatic differentiation and just-in-time compilation with JAX

There are no analytically coded derivatives in this branch. Instead, we use jax.grad, and we use jax.jit to compile for speed. In addition to substantially simplifying the code, automatic differentiation lets us quickly prototype alternative regularization approaches.
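As a minimal sketch of that pattern (the Poisson-style loss below is illustrative, not mushi's actual loss), one writes the objective as a plain function, differentiates it with jax.grad, and compiles both with jax.jit:

```python
import jax
import jax.numpy as jnp

# Illustrative Poisson-style loss: write it as an ordinary function,
# then get the gradient and compiled versions for free.
def loss(params, counts, design):
    rates = jnp.exp(design @ params)          # positive expected counts
    return jnp.sum(rates - counts * jnp.log(rates))

loss_jit = jax.jit(loss)
grad_jit = jax.jit(jax.grad(loss))            # gradient w.r.t. params

key = jax.random.PRNGKey(0)
design = jax.random.normal(key, (100, 5))
params = jnp.zeros(5)
counts = jnp.ones(100)
print(loss_jit(params, counts, design))
print(grad_jit(params, counts, design))
```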