ICSM / ampere

A tool to fit the SED and spectra of dusty objects to constrain, among other things, the dust properties and mineralogy

Low-level optimisations #67

Open pscicluna opened 1 year ago

pscicluna commented 1 year ago

I often still write Python as though it were FORTRAN, but there's no compiler to save me any more. As a result, there are a number of places where relatively low-level Python optimisation is possible.

For example, the "vectorised" posterior/prior/likelihood/simulator methods aren't really vectorised; they just loop over batches of parameters. There are faster ways of doing this, such as map() (which most inference packages use under the hood if we don't tell them the functions are vectorised), np.vectorize (although it can be slightly buggy), or simply requiring that the underlying likelihood/prior/simulate functions are genuinely vectorised when this option is turned on (see the sketch below).
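To make the difference concrete, here is a toy standalone sketch (not ampere's actual API; `loglike` is a made-up example function) of the three options for evaluating a scalar log-likelihood over a batch of parameter vectors. Only the last one actually vectorises the work:

```python
import numpy as np

def loglike(theta):
    # toy scalar log-likelihood for a single parameter vector `theta`
    return -0.5 * np.sum(theta ** 2)

batch = np.random.randn(1000, 5)   # 1000 parameter vectors of dimension 5

# 1. explicit loop over the batch (what the current "vectorised" wrappers do)
ll_loop = np.array([loglike(t) for t in batch])

# 2. map() -- same cost per call, but what many inference packages fall back to
ll_map = np.fromiter(map(loglike, batch), dtype=float)

# 3. genuinely vectorised: operate on the whole batch at once
ll_vec = -0.5 * np.sum(batch ** 2, axis=1)

assert np.allclose(ll_loop, ll_vec)
```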

Similarly, profiling shows that, for relatively fast models, a lot of time is spent in relatively few places, particularly evaluating probability distributions and constructing, factorising and sampling from covariance matrices (not unexpected!).

For probabilities, some quick tests show that Torch is about 3 times faster than SciPy, so this could give us a boost, and vectorising the evaluation provides orders-of-magnitude improvements (e.g. evaluating a vector of 1000 random numbers takes only about 30% longer than evaluating a single one). Other libraries (e.g. numpyro, since it is based on JAX) might be even faster than Torch, but we already have Torch in our dependencies. Still, JAX might be worth adding, since it could provide a number of optimisations throughout the code (and in future let us do automatic gradients!).

Covariance matrices are more annoying, and we need better ways of handling them: the Cholesky decomposition crashes if the matrix is not (numerically) positive semi-definite, so I switched the default to eigenvalue decomposition, which is slower. Torch may again provide a solution, since it has a lot of framework set up for lazy evaluation of kernel tensors (which is basically what we want to do), including jitter to ensure positive-definiteness, but this will require some benchmarking and testing. If similar solutions exist in pure NumPy, or with JAX or other packages, they are also worth investigating. A couple of sketches follow below.
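As a point of comparison for the probability evaluations, the quick tests above amount to something along these lines (a hypothetical standalone snippet, not ampere code); timings will vary by machine, but both calls are vectorised over all 1000 points:

```python
import numpy as np
import torch
from scipy import stats

x = np.random.randn(1000)

# scipy: one vectorised call over all 1000 points
lp_scipy = stats.norm.logpdf(x, loc=0.0, scale=1.0)

# torch: the same evaluation through torch.distributions
dist = torch.distributions.Normal(loc=0.0, scale=1.0)
lp_torch = dist.log_prob(torch.as_tensor(x))

# both give the same numbers; wrap each call in a timer to reproduce the
# rough factor-of-a-few difference quoted above
assert np.allclose(lp_scipy, lp_torch.numpy())
```

And for the covariance matrices, here is a minimal sketch of the jitter-plus-fallback idea in plain numpy (hypothetical helper, not anything currently in ampere): try the fast Cholesky first, add a growing diagonal jitter if it fails, and only fall back to the slower eigenvalue route as a last resort.

```python
import numpy as np

def safe_cholesky(cov, jitter=1e-10, max_tries=6):
    """Cholesky-factorise `cov`, adding diagonal jitter if the plain attempt fails."""
    n = cov.shape[0]
    current_jitter = 0.0
    for _ in range(max_tries):
        try:
            return np.linalg.cholesky(cov + current_jitter * np.eye(n))
        except np.linalg.LinAlgError:
            # scale the jitter to the matrix and grow it by 10x on each failure
            current_jitter = max(current_jitter * 10.0, jitter * np.mean(np.diag(cov)))
    # last resort: the slower eigenvalue route, clipping small negative eigenvalues
    evals, evecs = np.linalg.eigh(cov)
    return evecs * np.sqrt(np.clip(evals, 0.0, None))  # L such that L @ L.T ~ cov
```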

Please discuss these or further ideas here! Whenever we're ready to tackle any particular aspect of this, we can spin out new issues that are more specific.

pscicluna commented 1 year ago

After more digging, we might as well jump into JAX now, even if it places some limitations on us in the short term. It is building a big ecosystem fast, and it can be integrated into existing Python code very easily. The rough plan would be to bake JAX support into all Data classes, and use GPJax for fast, memory-efficient handling of covariance matrices and likelihood evaluations. Models will have to be smart enough to tell whether they will work entirely with JAX or not and provide a fallback; the easiest thing may be to provide two inheritance paths, one Model and one JAXModel, and raise an exception if you try to use the wrong one (a rough sketch follows below). I hope this can also make it easier to develop new data and model functionality, since it will make things a lot cleaner.
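To illustrate what the two inheritance paths might look like, here is a very rough, hypothetical sketch (none of these names are final) of a Model/JAXModel split, where JAX-only machinery refuses a plain numpy model up front rather than failing somewhere deep inside a jit trace:

```python
class Model:
    """Base class for models evaluated with plain numpy."""
    supports_jax = False

    def __call__(self, theta):
        raise NotImplementedError

class JAXModel(Model):
    """Base class for models whose __call__ is written in jax.numpy and is jit/vmap-safe."""
    supports_jax = True

def require_jax(model):
    """Raise a clear error if a JAX-only code path is handed a plain numpy model."""
    if not getattr(model, "supports_jax", False):
        raise TypeError(f"{type(model).__name__} does not support JAX; "
                        "use a JAXModel subclass or fall back to the numpy path.")
```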