harrispopgen / mushi

[mu]tation [s]pectrum [h]istory [i]nference
https://harrispopgen.github.io/mushi/
MIT License

version 1 release candidate #68

Closed wsdewitt closed 3 years ago

wsdewitt commented 3 years ago

Summary

Note: ignore changes to files in the docs/ directory; those are from the documentation build.

What to look at first

Under the hood

wsdewitt commented 3 years ago

Apparently something in today's release of the JAX package is broken, which is why the integration tests above failed on `import jax.numpy as np` (and why you might hit the same error on a fresh install). https://github.com/google/jax/issues/5374

Workaround:

pip install jaxlib==0.1.57
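In the meantime, here's a minimal sketch (my own illustration, not what mushi currently does) of guarding the import so a broken jaxlib wheel fails with an actionable message instead of a cryptic traceback:

```python
# Sketch only: wrap the JAX import to point users at the workaround above.
try:
    import jax.numpy as np  # mushi's array backend
except Exception as err:
    raise ImportError(
        "jax.numpy failed to import; try `pip install jaxlib==0.1.57` "
        "(see https://github.com/google/jax/issues/5374)"
    ) from err
```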
kamdh commented 3 years ago

I'm not certain whether prox_tv is a needed import any more. Now that we have our own solver for trend filtering, I figure we can use it to solve for the total variation case as well, right? That would make the code more self-contained, and maybe allow for jit.

kamdh commented 3 years ago

This may be more of a feature request, but I think it makes sense to refactor the optimization code into its own module. This could end up being very useful for other people. We could also consider trying to get it pulled into some bigger python optimization library (not sure what at the moment).

wsdewitt commented 3 years ago

@kamdh

I'm not certain whether prox_tv is a needed import any more. Now that we have our own solver for trend filtering, I figure we can use it to solve for the total variation case as well, right? That would make the code more self-contained, and maybe allow for jit.

The magic in the special ADMM (from the paper cited in my docstring) is that it recurses to the 0th-order case, and thus leverages those super-fast solvers (like tv-prox). Boyd et al. have a trend filtering ADMM that does not do this, and it is apparently slower, but I have not tried to implement it. It's possible that the non-recursive ADMM would be faster with JIT (and it would be nice to lose the dependency so we can upgrade Python, harrispopgen/mushi#59), but unfortunately banded Cholesky solves aren't yet supported in JAX, and that's the most expensive part of the ADMM update.
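For context, here's a rough sketch (my own illustration, not the mushi implementation) of why the banded solve matters: the x-update in a trend-filtering ADMM amounts to solving a system (I + rho * D^T D) x = b, where D is a k-th order difference operator, and that system matrix is banded. SciPy can factor it in banded form; JAX currently can't.

```python
# Illustration only (not mushi's code): banded Cholesky solve for the ADMM x-update.
import numpy as onp
from scipy.linalg import cholesky_banded, cho_solve_banded

def admm_x_update(b, rho, k):
    """Solve (I + rho * D^T D) x = b, with D the k-th order difference operator."""
    n = len(b)
    D = onp.diff(onp.eye(n), n=k, axis=0)   # (n - k, n) difference matrix (dense here for clarity)
    A = onp.eye(n) + rho * D.T @ D          # symmetric positive definite, bandwidth k
    ab = onp.zeros((k + 1, n))              # upper banded storage: row (k - d) holds the d-th superdiagonal
    for d in range(k + 1):
        ab[k - d, d:] = onp.diagonal(A, offset=d)
    c = cholesky_banded(ab)                 # banded Cholesky factorization
    return cho_solve_banded((c, False), b)  # cheap banded triangular solves
```

A quick sanity check like `admm_x_update(onp.random.rand(1000), rho=1.0, k=3)` exercises the banded path; doing the same update densely in JAX is the part that currently can't be avoided.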

wsdewitt commented 3 years ago

@kamdh

This may be more of a feature request, but I think it makes sense to refactor the optimization code into its own module. This could end up being very useful for other people. We could also consider trying to get it pulled into some bigger python optimization library (not sure what at the moment).

I totally agree: the optimization module is pretty impressive in its own right and shouldn't be chained to this use case. We can spin it out with git filter-branch in a way that preserves commit history/contributions. Maybe we don't need to do that for this PR merge, though; do you want to open an issue for this so we can discuss it separately (perhaps sometime after Feb 1)?

You might also want to play with the new (object-oriented) optimization module.

kamdh commented 3 years ago

I totally agree: the optimization module is pretty impressive in its own right and shouldn't be chained to this use case. We can spin it out with git filter-branch in a way that preserves commit history/contributions. Maybe we don't need to do that for this PR merge, though; do you want to open an issue for this so we can discuss it separately (perhaps sometime after Feb 1)?

Sounds like a plan to me.