I've started working on an implementation using JAX, which could yield speedups via just-in-time compilation without changing the current numerical methodology. But before going further, we wanted to ask whether such a solution would be acceptable?
thanks! I'll look into it.
Hi @WSDeWitt! this looks interesting. What are the operations that JAX tends to speed up? I suppose these are big matrix operations and anything where the same function is run on many datapoints? If so, this might be quite handy for the array operations used to calculate the sequence likelihoods etc. But for the calculation of the timings I wouldn't expect that much speed-up. But happy to touch base and discuss -- I might not understand what JAX is really about.
It can also speed up code with loops. IIUC, the just-in-time compilation does this by unrolling loops of a fixed size.
In the attached Jupyter notebook I cooked up a toy example showing ~1000x speedup for a simple quadrature.
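For illustration (this is just a minimal sketch of the idea, not the notebook itself): a trapezoidal quadrature written with jax.numpy and wrapped in jax.jit, so the whole evaluation compiles once and subsequent calls run on the compiled kernel.

```python
# Minimal sketch (not the attached notebook): JIT-compiling a simple
# trapezoidal quadrature with JAX.
import jax
import jax.numpy as jnp

@jax.jit
def trapezoid(y, dx):
    # trapezoidal rule on a regular grid: dx * (sum(y) - (y[0] + y[-1]) / 2)
    return dx * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))

def integrate(f, a, b, n=100_000):
    t = jnp.linspace(a, b, n)
    return trapezoid(f(t), (b - a) / (n - 1))

# e.g. the integral of sin(t) from 0 to pi is 2;
# the first call compiles, repeated calls reuse the compiled function
print(integrate(jnp.sin, 0.0, jnp.pi))
```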
this looks promising. Let's chat about it!
Perhaps autodiff for https://github.com/neherlab/treetime/blob/master/treetime/merger_models.py#L173 ?
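To make the suggestion concrete, a toy example of what autodiff would buy; the objective below is a stand-in for illustration, not the code in merger_models.py:

```python
import jax
import jax.numpy as jnp

# stand-in objective: negative log-likelihood of exponential waiting times
def neg_log_likelihood(rate, waiting_times):
    return -jnp.sum(jnp.log(rate) - rate * waiting_times)

# jax.grad gives the exact derivative with respect to the rate parameter,
# which could feed a gradient-based optimizer instead of a grid search
grad_fn = jax.grad(neg_log_likelihood)
print(grad_fn(0.5, jnp.array([0.1, 0.4, 1.2])))
```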
Is there any progress on this? Can someone share the JAX commands to compile TreeTime so it can run in parallel over large datasets? Even a draft, pre-release version. I think there is an urgent need for a more optimized, compiled version of TreeTime to handle tens of thousands of samples in parallel ... I'd be willing to try it on different machine configurations on AWS and on HPCCs.
I will use TreeTime (within Nextstrain's Augur) on an HPCC ... I know that TreeTime will be a bottleneck ... given it is all written in Python, has anyone thought of embedding mpi4py within the Python code? ... I am not a Python programmer (and I acknowledge Python is not the ideal language to run on an HPCC), but perhaps "Message Passing Interface" (as in Fortran and C++) would make it faster to chase polytomies on a phylogenetic tree. Indeed, MPI is an ideal wrapper to spread the workload across different cluster nodes or cores (or both) on an HPCC. In principle, a given TreeTime algorithm would be split across cores/nodes, each handling a different section of the phylogenetic tree. Now I admit, it is easier said than done: it may not be easy, but it could indeed improve computing time.
On the same note, I tried having polytomies and near-zero-length branches collapsed by IQ-TREE (the -czb option) before running TreeTime (which would be used for rooting and clock dating). That algorithm within IQ-TREE is not parallelized either and is a bottleneck as well; however, when IQ-TREE does it before TreeTime, it speeds up the work done by TreeTime considerably ... Of course, it does not change the total computing time overall: either IQ-TREE does the slow work or TreeTime does, and either way it is a bottleneck ...
My point is that these algorithms should be parallelized ... is MPI (e.g., Python's mpi4py) worth considering? Has anyone done so? I am always happy to test on our different HPCC flavors. If we were not in such a state of emergency with this COVID-19 pandemic, I would have loved to program this myself in the raw Python code, but that won't be an option anytime soon for me.
Thanks all for the good work!
Sebastien
@sebdart One major speedup for treetime would be the FFT implementation of the convolution. This would reduce the computational cost of the marginal reconstruction by a large factor. But there are edge cases in which this doesn't seem robust and I haven't managed to get to the bottom of it yet. The draft PR is here:
https://github.com/neherlab/treetime/pull/121
Other than that, I fully agree that TreeTime should be parallelized, but it wasn't designed with this in mind. As @WSDeWitt discovered, most steps access deeply shared common objects. This isn't strictly necessary; it just happens to be the way it got implemented. But fixing this would require a substantial redesign.
For some of the basic sequence evolution parts, parallelization via numpy might be an option, but this isn't the bottleneck for the time-scaled tree part.
I don't have much experience here, but I wonder if Numba's @jit in "object mode" would allow speeding up pure numpy loops/ops without requiring a redesign of the TreeTime interfaces.
version 0.9.0 now uses FFT for convolutions.
The problem
TreeTime is (up to polytomy resolution) a method that scales linearly in tree size, but some steps are slow. This is partly because it is written in Python, and partly due to a suboptimal implementation. In timetree estimation, the slowest step is the calculation of convolutions or maxima of functions of the form `f(t-tau) g(tau)`. While not intrinsically hard, the challenge is to make this robust and numerically stable for branch lengths of order 1e-7 to 10. The probability that a branch has a particular length, or that a node sits at a particular position, is represented as a linear interpolation object of the logarithm of the function. The pivot points of the interpolation are chosen densely around the peak of the distribution and sparsely everywhere else.
In principle, the complexity of this problem should be log-linear in the required accuracy, but the current implementation is quadratic in the accuracy.
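A rough sketch of why the current approach is quadratic (toy code, not the actual interpolation objects): for every target time t, the combined function is evaluated over the full grid of tau values.

```python
import numpy as np

def max_convolve_log(t_grid, tau_grid, log_f, log_g):
    # h(t) = max over tau of [ log f(t - tau) + log g(tau) ]
    # one full pass over tau_grid per point of t_grid -> O(n_t * n_tau)
    h = np.full(t_grid.shape, -np.inf)
    for i, t in enumerate(t_grid):
        h[i] = np.max(log_f(t - tau_grid) + log_g(tau_grid))
    return h
```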
Possible solutions
Use FFT to calculate convolutions for marginal inference
This requires a larger regular grid of points on which the functions are stored. An experimental version of this strategy is implemented in the fft branch (https://github.com/neherlab/treetime/tree/fft). The FFT-based convolution is implemented here: https://github.com/neherlab/treetime/blob/fft/treetime/node_interpolator.py#L193
This greatly accelerates the inference, but it only works for purely marginal inference, which can result in inconsistent node placement (since each node's position is optimized while marginalizing over all others). Edge cases, robustness, and stability are not established, and this is still buggy.
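For concreteness, the generic version of the trick on a regular grid looks roughly like this (a sketch only; the actual implementation in node_interpolator.py on the fft branch also has to deal with the adaptive pivots and the edge cases mentioned above):

```python
import numpy as np
from scipy.signal import fftconvolve

def convolve_log(log_f, log_g, dt):
    # both functions stored as log-values on the same evenly spaced grid;
    # shift by the maxima to avoid underflow of exponentially small numbers,
    # convolve in linear space via FFT, then return to log space
    cf, cg = log_f.max(), log_g.max()
    conv = fftconvolve(np.exp(log_f - cf), np.exp(log_g - cg), mode="full") * dt
    return np.log(np.maximum(conv, np.finfo(float).tiny)) + cf + cg
```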
More efficient numerical optimization of integrands in joint inference
We previously just searched for the peak of the function on the sparse grid. Numerical optimization should find the peak in logarithmic time, and we can afford a denser grid in that case.
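Sketched with scipy's bounded scalar minimizer (illustrative only; the bounds and objective would come from the branch-length interpolators):

```python
from scipy.optimize import minimize_scalar

def peak_position(log_f, log_g, t, lower, upper):
    # find tau maximizing log f(t - tau) + log g(tau) for a fixed t
    res = minimize_scalar(lambda tau: -(log_f(t - tau) + log_g(tau)),
                          bounds=(lower, upper), method="bounded")
    return res.x, -res.fun
```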
Pitfalls
Generally, this problem is tricky since one has to respect hard constraints (the ordering imposed by sampling dates) and handle exponentially small numbers.