desihub / fastspecfit

Fast spectral synthesis and emission-line fitting of DESI spectra.
https://fastspecfit.readthedocs.org
BSD 3-Clause "New" or "Revised" License

fastspecfit at risk of becoming slowspecfit #98

Open moustakas opened 1 year ago

moustakas commented 1 year ago

The code refactoring in #92, #95, and #96 added much more robust and extensive fitting capabilities (even after removing the dependence on the very slow astropy.modeling classes) at the expense of speed. For example, for the non-linear least-squares fitting used to model the emission lines, I switched from the Levenberg-Marquardt (lm) algorithm in scipy.optimize.least_squares to the Trust Region Reflective (trf) algorithm, which incorporates bounds into the minimization and is significantly more robust, but slower!
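A minimal sketch of why bounds matter here, assuming a single Gaussian line with hypothetical parameters (amplitude, center, sigma) and a noise-free, absorption-like dip in the data. The unconstrained lm fit should return a negative "emission" amplitude, while trf with the amplitude bounded at zero stays physical. SciPy's `least_squares` rejects bounds when `method="lm"`, which is why switching to trf is needed in the first place.

```python
import numpy as np
from scipy.optimize import least_squares

def residuals(params, wave, flux):
    """Residuals of a single Gaussian line; params = (amplitude, center, sigma)."""
    amp, center, sigma = params
    model = amp * np.exp(-0.5 * ((wave - center) / sigma) ** 2)
    return model - flux

wave = np.linspace(6540.0, 6590.0, 200)  # Angstrom (illustrative window)
# Noise-free dip: the unconstrained optimum has a negative amplitude.
flux = -0.3 * np.exp(-0.5 * ((wave - 6564.6) / 2.5) ** 2)
p0 = [0.1, 6563.0, 3.0]

# lm: unconstrained, so nothing stops the amplitude from going negative.
fit_lm = least_squares(residuals, p0, args=(wave, flux), method="lm")

# trf: bounds keep amplitude >= 0 and the center/width in a plausible range.
lower = [0.0, 6555.0, 0.5]
upper = [np.inf, 6575.0, 10.0]
fit_trf = least_squares(residuals, p0, args=(wave, flux), method="trf",
                        bounds=(lower, upper))

print("lm  amplitude:", fit_lm.x[0])
print("trf amplitude:", fit_trf.x[0])

# For contrast, lm refuses bounds outright with a ValueError.
try:
    least_squares(residuals, p0, args=(wave, flux), method="lm", bounds=(lower, upper))
except ValueError as err:
    print("lm with bounds:", err)
```

The price is that trf keeps its iterates strictly inside the bounds and, by default, estimates the Jacobian by finite differences at every step.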

The bottleneck in trf is the numerical differentiation step (see, e.g., https://stackoverflow.com/questions/68507176/faster-scipy-optimizations), so I think the code is a great candidate to be ported to GPUs, where algorithmic differentiation using, e.g., JAX can be many times faster.
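One way to remove the finite-difference cost, even before any GPU port, is to hand `least_squares` an explicit Jacobian through its `jac` argument. The sketch below writes it analytically for a single Gaussian line; the functions `residuals` and `jacobian` are hypothetical stand-ins, not fastspecfit's objective. In a JAX port, something like `jax.jacfwd` could generate the same matrix automatically instead of deriving it by hand.

```python
import numpy as np
from scipy.optimize import least_squares

def residuals(params, wave, flux, ivar):
    """Chi-weighted residuals of one Gaussian line; params = (amplitude, center, sigma)."""
    amp, center, sigma = params
    u = (wave - center) / sigma
    return np.sqrt(ivar) * (amp * np.exp(-0.5 * u**2) - flux)

def jacobian(params, wave, flux, ivar):
    """Analytic d(residuals)/d(params), shape (npix, 3); replaces finite differences."""
    amp, center, sigma = params
    u = (wave - center) / sigma
    g = np.exp(-0.5 * u**2)
    w = np.sqrt(ivar)
    return np.column_stack([
        w * g,                        # d r / d amplitude
        w * amp * g * u / sigma,      # d r / d center
        w * amp * g * u**2 / sigma,   # d r / d sigma
    ])

rng = np.random.default_rng(1)
wave = np.linspace(6540.0, 6590.0, 200)  # Angstrom (illustrative window)
noise = 0.05
flux = np.exp(-0.5 * ((wave - 6564.6) / 2.5) ** 2) + rng.normal(0.0, noise, wave.size)
ivar = np.full_like(wave, 1.0 / noise**2)

p0 = [0.5, 6562.0, 4.0]
bounds = ([0.0, 6555.0, 0.5], [np.inf, 6575.0, 10.0])
args = (wave, flux, ivar)

# Same problem, with and without the explicit Jacobian.
fit_fd = least_squares(residuals, p0, args=args, method="trf", bounds=bounds)
fit_exact = least_squares(residuals, p0, jac=jacobian, args=args, method="trf", bounds=bounds)
print(fit_fd.x, fit_exact.x)
```

Both fits should land on the same solution; the difference is that the second never evaluates the residuals just to approximate derivatives.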

In addition, for the stellar continuum fitting (including inferring the velocity dispersion), the main algorithm is the non-negative least-squares fitting provided by scipy.optimize.nnls. Despite its Fortran bindings, nnls is still a bottleneck; the other slow pieces are the resampling of the templates at the redshift of each object and the convolution with the resolution matrix (here, I'm using Redrock's trapz_rebin algorithm, which already uses numba/jit for speed but still takes non-negligible time).
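For readers unfamiliar with the continuum step, here is a minimal sketch of a non-negative template fit with `scipy.optimize.nnls`, assuming three toy templates (hypothetical stand-ins for stellar population spectra) and Gaussian noise. Scaling each row by the square root of the inverse variance turns the plain least-squares norm that nnls minimizes into a chi-squared.

```python
import numpy as np
from scipy.optimize import nnls

rng = np.random.default_rng(0)
wave = np.linspace(3600.0, 9800.0, 500)  # Angstrom

# Three hypothetical non-negative "templates" standing in for stellar population spectra.
templates = np.column_stack([
    (wave / 5000.0) ** -2,                 # blue continuum
    np.ones_like(wave),                    # flat continuum
    1.0 + 0.5 * np.sin(wave / 150.0),      # template with spectral structure
])                                         # shape (nwave, ntemplate)

true_coeff = np.array([2.0, 0.0, 0.5])
noise = 0.01
flux = templates @ true_coeff + rng.normal(0.0, noise, wave.size)
ivar = np.full_like(wave, 1.0 / noise**2)

# Scale each row by sqrt(ivar) so that nnls minimizes chi^2 subject to coeff >= 0.
w = np.sqrt(ivar)
coeff, rnorm = nnls(templates * w[:, None], flux * w)
print(coeff, rnorm**2)  # best-fit coefficients and chi^2
```

In the real pipeline each template column would first be redshifted, resampled, and convolved with the resolution matrix, which is where the trapz_rebin cost comes in.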

@dmargala @marcelo-alvarez @craigwarner-ufastro @sbailey and others---I'd be grateful for any thoughts or insight for how to proceed.

In the meantime, here are some profiling results:

After logging into Perlmutter:

source /global/cfs/cdirs/desi/software/desi_environment.sh 23.1
module load fastspecfit/2.0.0

python -m cProfile -o fastspec.prof /global/common/software/desi/perlmutter/desiconda/20230111-2.1.0/code/fastspecfit/2.0.0/bin/fastspec \
  $DESI_ROOT/spectro/redux/fuji/tiles/cumulative/80613/20210324/redrock-4-80613-thru20210324.fits \
  -o fastspec.fits --targetids 39633345008634465

I get the following profile. (Note: I wasn't sure how to launch snakeviz on Perlmutter, so I copied the fastspec.prof file to my laptop. Also, I'm ignoring the I/O for now because those steps are slow for a single object but should be a small fraction of the total time when fitting a full healpixel.)
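If a GUI like snakeviz is not handy on the login node, the standard-library `pstats` module can summarize a `.prof` file directly (and `python -m pstats fastspec.prof` opens an interactive text browser). The sketch below profiles a hypothetical stand-in function, `slow_sum`, so it runs anywhere; in practice one would point `pstats.Stats` at `fastspec.prof`.

```python
import cProfile
import io
import os
import pstats
import tempfile

def slow_sum(n):
    """Hypothetical stand-in for a fastspec workload."""
    return sum(i * i for i in range(n))

# Produce a .prof file in the same format as `python -m cProfile -o fastspec.prof ...`.
prof_path = os.path.join(tempfile.mkdtemp(), "example.prof")
profiler = cProfile.Profile()
profiler.runcall(slow_sum, 200_000)
profiler.dump_stats(prof_path)

# Read it back without a GUI: top 10 functions by cumulative time.
stream = io.StringIO()
stats = pstats.Stats(prof_path, stream=stream)
stats.strip_dirs().sort_stats("cumulative").print_stats(10)
report = stream.getvalue()
print(report)
```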

[Screenshots: profiling results for a single target, 2023-01-26]
moustakas commented 1 year ago

Here are the results from fitting 50 objects:

python -m cProfile -o fastspec-50.prof /global/common/software/desi/perlmutter/desiconda/20230111-2.1.0/code/fastspecfit/2.0.0/bin/fastspec   \
  $DESI_ROOT/spectro/redux/fuji/tiles/cumulative/80613/20210324/redrock-4-80613-thru20210324.fits \
  --ntargets 50 -o fastspec-50.fits

As expected, the I/O time becomes negligible (and this is with Perlmutter in a degraded state), and the time is entirely dominated by the line-fitting (foremost) and the continuum-fitting (secondarily). Within the continuum-fitting, nnls is actually pretty negligible compared to the trapezoidal rebinning (called by smooth_and_resample), the Gaussian broadening (part of the velocity dispersion fitting), and the filter convolutions (used to synthesize photometry).
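To make the Gaussian-broadening step concrete, here is a minimal sketch of one common approach: on a grid that is uniform in log10(wavelength), each pixel is a constant velocity step, so velocity dispersion becomes a convolution with a fixed Gaussian kernel. The helper `broaden_loglam` and its arguments are hypothetical, not fastspecfit's implementation. Its cost grows with the kernel width, and it is repeated for every trial dispersion, which is one reason this step shows up in the profile.

```python
import numpy as np

C_KMS = 299792.458  # speed of light [km/s]

def broaden_loglam(flux, dloglam, vdisp):
    """Gaussian velocity broadening of a spectrum sampled uniformly in log10(wavelength).

    Hypothetical helper: `dloglam` is the pixel size in log10(Angstrom) and
    `vdisp` the line-of-sight velocity dispersion [km/s].
    """
    flux = np.asarray(flux, dtype=float)
    if vdisp <= 0:
        return flux.copy()
    # d(lambda)/lambda = ln(10) * d(log10 lambda) = dv/c, so a constant
    # log10 step corresponds to a constant velocity step per pixel.
    dv_pix = np.log(10.0) * C_KMS * dloglam
    sigma_pix = vdisp / dv_pix
    half = int(np.ceil(5.0 * sigma_pix))   # truncate the kernel at 5 sigma
    x = np.arange(-half, half + 1)
    kernel = np.exp(-0.5 * (x / sigma_pix) ** 2)
    kernel /= kernel.sum()                 # unit area conserves flux away from the edges
    return np.convolve(flux, kernel, mode="same")

# Usage: broaden a delta function by 250 km/s on a 1e-4 dex grid.
spike = np.zeros(2001)
spike[1000] = 1.0
broadened = broaden_loglam(spike, 1e-4, 250.0)
```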

[Screenshot: profiling results for 50 targets, 2023-01-27]

Zooming into the continuum-fitting:

[Screenshot: zoom into the continuum-fitting portion of the profile, 2023-01-27]
aphearin commented 1 year ago

@moustakas kudos for paring fastspecfit down to the bone and eliminating all of the bottlenecks besides the actual fitting algorithm. For problems of modest dimension (Ndim <~ 50-100), the most robust and fastest fitter I've used that implements bounds (by a wide margin) is the L-BFGS algorithm, but this algorithm leans hard on accurate gradients because it builds its approximation to the Hessian from them and uses that to condition the gradient descent. Higher-order derivatives can be slow to compute numerically, and it can get really tiresome to check their accuracy, but they come with guaranteed machine precision "for free" if the likelihood/cost function is implemented in an autodiff library like JAX. Speedups can be an order of magnitude or more when deploying algorithms based on higher-order gradients on modern tensor-core GPUs.
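The role of exact gradients in bounded L-BFGS can be sketched with SciPy's L-BFGS-B on the Rosenbrock function, using SciPy's built-in `rosen` and `rosen_der` as a stand-in for a spectral likelihood. The first call supplies the exact gradient; the second lets SciPy fall back to finite differences. In a JAX implementation, `jax.grad` of the cost function would play the role of the hand-written `rosen_der`.

```python
import numpy as np
from scipy.optimize import minimize, rosen, rosen_der

x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
bounds = [(-2.0, 2.0)] * x0.size  # box constraints on every parameter

# Exact gradient: L-BFGS-B builds its curvature estimate from exact gradient differences.
res_exact = minimize(rosen, x0, jac=rosen_der, method="L-BFGS-B", bounds=bounds)

# No gradient supplied: every gradient is approximated by finite differences.
res_fd = minimize(rosen, x0, method="L-BFGS-B", bounds=bounds)

print(res_exact.x, res_exact.nit)
print(res_fd.x, res_fd.nit)
```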

The level of pain it would require to reimplement the internals of fastspecfit in JAX depends entirely on the nitty-gritty details of the algorithms you're using: sometimes this is rather straightforward, sometimes it's a bit of a lift, and sometimes it requires abandoning certain sophisticated iterative algorithms in favor of simpler non-adaptive ones. The DSPS paper has some in-the-weeds pointers for JAX implementations for the case of a traditional SPS approach to the problem. The DSPS library is open-source, but it is really more of a collection of kernels and not yet a well-documented library, so I imagine it might not be very transparent to look at independently. I'd be happy to chat about this further in case you decide to go down this road.

moustakas commented 1 year ago

Thanks for the comments @aphearin. I'll take a look at your links.

> The level of pain it would require to reimplement the internals of fastspecfit into JAX depends entirely on the nitty gritty details of the algorithms you're using - sometimes this is rather straightforward, sometimes it's a bit of a lift, and sometimes it requires abandoning certain sophisticated iterative algorithms in favor of simpler non-adaptive ones.

At @dmargala's urging, I tried to write the fitting "guts" using as much pure NumPy as I could. For example, here's the objective function that is passed to scipy.optimize.least_squares: https://github.com/desihub/fastspecfit/blob/main/py/fastspecfit/emlines.py#L74-L113

One notable (but not dominant) bottleneck is the construction of the emission-line model (a sum of Gaussians) for each trial set of parameters being optimized (anywhere from 10 to 50 parameters): https://github.com/desihub/fastspecfit/blob/main/py/fastspecfit/emlines.py#L28-L72
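A loop-free sum-of-Gaussians model is the kind of pure-array code that ports to `jax.numpy` with few changes. The sketch below assumes a hypothetical parameterization (per-line amplitude, velocity shift, and velocity dispersion, packed into one flat vector for the optimizer); it is not the layout used in `emlines.py`. Broadcasting to an `(nline, npix)` array replaces a Python loop over lines.

```python
import numpy as np

C_KMS = 299792.458  # speed of light [km/s]

def emline_model(wave, restwave, amp, vshift, sigma_kms, redshift):
    """Sum of Gaussian emission lines on the observed-frame grid `wave` [Angstrom].

    Hypothetical parameterization: per-line amplitude, velocity shift [km/s],
    and velocity dispersion [km/s]; all per-line arrays have shape (nline,).
    """
    center = restwave * (1.0 + redshift) * (1.0 + vshift / C_KMS)  # (nline,)
    width = center * sigma_kms / C_KMS                              # (nline,)
    # Broadcast to (nline, npix) instead of looping over lines, then sum.
    u = (wave[None, :] - center[:, None]) / width[:, None]
    return (amp[:, None] * np.exp(-0.5 * u**2)).sum(axis=0)

def objective(params, wave, flux, ivar, restwave, redshift):
    """Chi-weighted residuals for least_squares; params packs (amp, vshift, sigma) per line."""
    amp, vshift, sigma_kms = np.split(np.asarray(params, dtype=float), 3)
    model = emline_model(wave, restwave, amp, vshift, sigma_kms, redshift)
    return np.sqrt(ivar) * (model - flux)

# Usage: approximate vacuum rest wavelengths of H-beta and the [OIII] doublet at z = 0.1.
restwave = np.array([4862.683, 4960.295, 5008.239])
wave = np.linspace(5300.0, 5600.0, 1500)
params = np.concatenate([[1.0, 0.7, 2.1], [0.0, 0.0, 0.0], [80.0, 80.0, 80.0]])
flux = emline_model(wave, restwave, params[:3], params[3:6], params[6:], 0.1)
resid = objective(params, wave, flux, np.ones_like(wave), restwave, 0.1)
```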

For the optimization, I was thinking of using one of the constrained optimization algorithms in https://jaxopt.github.io/stable/constrained.html, although I have no idea how to get JAX (and jaxlib) running at NERSC (whether in a Docker/Shifter container or not; see, e.g., https://github.com/google/jax/issues/6340).

aphearin commented 1 year ago

OK, this actually looks pretty tractable to me. The usual things that need to be rewritten are control flow within for loops, which takes a little fiddling but is usually not too bad. One blocker to be aware of is while loops, which are a no-go for taking gradients with JAX (while loops are supported by JAX, but I think reverse-mode differentiation through them is unsupported, and I don't know whether support is coming anytime soon). I think rewriting the _trapz_rebin kernel in JAX might be the first focal point; it doesn't look too bad at first glance.
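As an illustration of removing data-dependent loops, here is a hypothetical loop-free trapezoidal rebinning written in NumPy (not Redrock's kernel). It averages the piecewise-linear interpolant of the input over each output bin: a cumulative trapezoid integral at the input nodes, `searchsorted` to locate each bin edge, and an exact partial-segment integral. It assumes the bin edges lie inside the input range. `jax.numpy` provides `searchsorted`, `cumsum`, `clip`, and `concatenate`, so a port should be close to line for line, though that is untested here.

```python
import numpy as np

def trapz_rebin_vectorized(x, y, edges):
    """Mean of the piecewise-linear interpolant of (x, y) over each bin in `edges`.

    Hypothetical loop-free sketch (not Redrock's kernel). Assumes `x` is strictly
    increasing and every edge lies within [x[0], x[-1]].
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    edges = np.asarray(edges, dtype=float)

    # Cumulative trapezoidal integral at the input nodes, with cum[0] = 0.
    cum = np.concatenate([[0.0], np.cumsum(0.5 * (y[1:] + y[:-1]) * np.diff(x))])

    # Index of the input segment [x[i], x[i+1]] containing each edge; the per-edge
    # search loop is replaced by searchsorted + clip, which vectorizes cleanly.
    i = np.clip(np.searchsorted(x, edges, side="right") - 1, 0, x.size - 2)

    # Exact integral of the linear segment from x[i] up to the edge.
    dx = edges - x[i]
    slope = (y[i + 1] - y[i]) / (x[i + 1] - x[i])
    cum_at_edges = cum[i] + y[i] * dx + 0.5 * slope * dx**2

    # Integral over each output bin divided by its width.
    return np.diff(cum_at_edges) / np.diff(edges)

# Usage: for a linear input, each bin average equals the value at the bin midpoint.
x = np.linspace(0.0, 10.0, 101)
y = 2.0 * x + 1.0
edges = np.array([0.5, 1.5, 3.0, 7.25])
rebinned = trapz_rebin_vectorized(x, y, edges)
```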

moustakas commented 1 year ago

trapz_rebin is actually also one of the bottlenecks in Redrock (from which I stole this code, shamelessly but with permission from @sbailey), but I'm not sure whether someone is already porting it from numba.jit to a GPU-optimized implementation.

aphearin commented 1 year ago

I'm not seeing anything in the _trapz_rebin kernel that looks like a blocker.

craigwarner-ufastro commented 1 year ago

@moustakas in case it's helpful, I recently wrote a GPU-optimized version of trapz_rebin that has now been merged into the main branch of Redrock. If you are rebinning to many redshifts, it is much faster.

erwanp commented 1 year ago

Hello, I was reading your comments about JAX-accelerated spectral synthesis: do you know of ExoJAX, and is it suited to your application?