Apparently something in today's release of the JAX package is broken, which is why the integration tests above failed on `import jax.numpy as np` (and why you might be getting the same error if doing a fresh install): https://github.com/google/jax/issues/5374

Workaround: `pip install jaxlib==0.1.57`
I'm not certain whether `prox_tv` is a needed import any more. Now that we have our own solver for trend filtering, I figure we can use it to solve the total variation case as well, right? That would make the code more self-contained, and maybe allow for `jit`.
This may be more of a feature request, but I think it makes sense to refactor the optimization code into its own module. It could end up being very useful for other people. We could also consider trying to get it pulled into a larger Python optimization library (not sure which at the moment).
@kamdh
> I'm not certain whether `prox_tv` is a needed import any more. Now that we have our own solver for trend filtering, I figure we can use it to solve the total variation case as well, right? That would make the code more self-contained, and maybe allow for `jit`.
The magic in the special ADMM (from the paper cited in my docstring) is that it recurses to the 0th-order case, and thus leverages those super fast solvers (like the TV prox). Boyd et al. have a trend filtering ADMM that does not do this, and it is apparently slower, but I have not tried to implement it. It's possible that the non-recursive ADMM would be faster with JIT (and it would be nice to lose the dependency so we can upgrade Python, harrispopgen/mushi#59), but unfortunately banded Cholesky solves aren't yet supported in JAX, and that's the most expensive part of the ADMM update.
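For concreteness, here is a minimal sketch of that structure, assuming the `prox_tv` package and SciPy's banded Cholesky routines. The function name, arguments, and defaults are hypothetical; this is illustrative, not the mushi implementation:

```python
# Illustrative only: a special-ADMM-style trend filter in which the z-update
# recurses to the 0th-order (total variation) prox via prox_tv, and the
# beta-update is a banded linear solve whose Cholesky factor is computed once
# and reused across iterations (the expensive step that JAX cannot yet do banded).
import numpy as np
import prox_tv as ptv
from scipy.linalg import cholesky_banded, cho_solve_banded
from scipy.sparse import diags, identity


def trend_filter_admm(y, k=2, lam=1.0, rho=1.0, n_iter=20):
    """Sketch of minimizing 0.5*||y - beta||^2 + lam*||D^(k+1) beta||_1 via ADMM."""
    n = len(y)
    # k-th order difference operator D, built by composing first differences
    D = identity(n, format="csc")
    for _ in range(k):
        m = D.shape[0]
        D = diags([-1.0, 1.0], [0, 1], shape=(m - 1, m), format="csc") @ D
    A = (identity(n) + rho * (D.T @ D)).toarray()

    # store A in lower banded form (bandwidth k) and factor it once (cached)
    ab = np.zeros((k + 1, n))
    for i in range(k + 1):
        ab[i, : n - i] = np.diagonal(A, -i)
    chol = cholesky_banded(ab, lower=True)

    z = np.zeros(n - k)
    u = np.zeros(n - k)
    for _ in range(n_iter):
        # beta-update: banded Cholesky solve against the cached factor
        beta = cho_solve_banded((chol, True), y + rho * (D.T @ (z - u)))
        # z-update: 1D total variation prox (the 0th-order recursion)
        z = ptv.tv1_1d(D @ beta + u, lam / rho)
        # dual update
        u = u + D @ beta - z
    return beta
```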
@kamdh
> This may be more of a feature request, but I think it makes sense to refactor the optimization code into its own module. It could end up being very useful for other people. We could also consider trying to get it pulled into a larger Python optimization library (not sure which at the moment).
I totally agree: the optimization module is pretty impressive in its own right, and shouldn't be chained to this use case. We can spin it out with `git filter-branch` in a way that preserves commit history/contributions. Maybe we don't need to do that for this PR merge, though; do you want to open an issue for this so we can discuss it separately (perhaps sometime after Feb 1)?
You might also want to play with the new (object-oriented) optimization module.
> I totally agree: the optimization module is pretty impressive in its own right, and shouldn't be chained to this use case. We can spin it out with `git filter-branch` in a way that preserves commit history/contributions. Maybe we don't need to do that for this PR merge, though; do you want to open an issue for this so we can discuss it separately (perhaps sometime after Feb 1)?
Sounds like a plan to me.
**Summary**

Note: ignore changes to files in the `docs/` directory; this is the documentation build.

**What to look at first**

- `kSFS.infer_eta` and `kSFS.infer_mush`. The API docs on the new interface can be viewed by opening `docs/index.html` in your clone.
- `docsrc/notebooks/simulation.ipynb`, but be careful about committing your changes, because this notebook drives docs content.

**Under the hood**
- `mushi.optimization` module with abstraction and inheritance to avoid a lot of duplicated code. Added a trend filtering optimizer class based on the recursive ADMM of Ramdas and Tibshirani (this serves as the prox operator in the outer optimization routine when fitting demography or mush). I was able to get this running quite fast by caching Cholesky decompositions and using the fast prox-tv module for dual variable updates. I find that about 20 iterations of ADMM are plenty (although the default is 100).
- `utils` model `r` that is a learned parameter.
- `pts` and `ta` when running `kSFS.infer_eta()`.
- `mushi.loss_functions`. Inference can be done using any loss function from this module (see the illustrative sketch after this list).
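For illustration, a loss in the style this module supports might look like the sketch below. The function name and signature are hypothetical (not necessarily what `mushi.loss_functions` exposes); it just shows a Poisson-type loss written with `jax.numpy` so the outer optimizer can differentiate and jit it:

```python
# Hypothetical example of a loss function (not necessarily mushi's API):
# the negative Poisson log-likelihood of observed counts given expected values,
# dropping the log(observed!) term, which is constant in the parameters.
import jax.numpy as jnp


def poisson_loss(expected, observed):
    """Negative Poisson log-likelihood, up to an additive constant."""
    return jnp.sum(expected - observed * jnp.log(expected))
```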