moustakas opened this issue 1 year ago
Here are the results from fitting 50 objects:
```
python -m cProfile -o fastspec-50.prof /global/common/software/desi/perlmutter/desiconda/20230111-2.1.0/code/fastspecfit/2.0.0/bin/fastspec \
  $DESI_ROOT/spectro/redux/fuji/tiles/cumulative/80613/20210324/redrock-4-80613-thru20210324.fits \
  --ntargets 50 -o fastspec-50.fits
```
As expected, the I/O time becomes negligible (and this is with perlmutter in a degraded state) and the runtime is entirely dominated by the line-fitting (foremost) and the continuum-fitting (secondarily). Within the continuum-fitting, `nnls` is actually pretty negligible compared to the trapezoidal rebinning (called by `smooth_and_resample`), the Gaussian broadening (part of the velocity-dispersion fitting), and the filter convolutions (used to synthesize photometry).
Zooming into the continuum-fitting:
@moustakas kudos for paring fastspecfit down to the bone and eliminating all of the bottlenecks besides the actual fitting algorithm. For problems of modest dimension (Ndim <~ 50-100), the fastest and most robust fitter I've used that implements bounds (kind of by far) is the L-BFGS algorithm, but this algorithm leans hard on accurate gradients because it uses an approximate Hessian to condition the gradient descent. Higher-order derivatives can be slow to compute numerically, and it can get really tiresome to check their accuracy, but they come with guaranteed machine precision "for free" if the likelihood/cost function is implemented in an autodiff library like JAX. Speedup factors can be an order of magnitude or more when deploying algorithms based on higher-order gradients on modern tensor-core GPUs.
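As a toy illustration of the autodiff point above (a sketch only: `chi2` here is a made-up single-Gaussian cost function, not anything from fastspecfit), `jax.grad` returns exact gradients of the cost with no finite-difference step-size tuning:

```python
import jax
import jax.numpy as jnp

def chi2(params, x, data):
    # Toy cost function: a single Gaussian line with (amplitude, center, width).
    amp, mu, sigma = params
    model = amp * jnp.exp(-0.5 * ((x - mu) / sigma) ** 2)
    return jnp.sum((model - data) ** 2)

# Reverse-mode autodiff gives the exact gradient "for free".
grad_chi2 = jax.grad(chi2)

x = jnp.linspace(-5.0, 5.0, 100)
data = jnp.exp(-0.5 * x**2)           # unit Gaussian as fake data
params = jnp.array([1.2, 0.1, 0.9])
g = grad_chi2(params, x, data)        # exact d(chi2)/d(params), shape (3,)
```

The same `grad_chi2` can be jitted (`jax.jit`) and batched (`jax.vmap`) across objects, which is where the GPU speedups come from.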
The level of pain it would require to reimplement the internals of fastspecfit into JAX depends entirely on the nitty gritty details of the algorithms you're using - sometimes this is rather straightforward, sometimes it's a bit of a lift, and sometimes it requires abandoning certain sophisticated iterative algorithms in favor of simpler non-adaptive ones. The DSPS paper has some in-the-weeds pointers for JAX implementations for the case of a traditional SPS approach to the problem. The DSPS library is open-source, but this is really more of a collection of kernels and not yet a well-documented library so I imagine it might not be very transparent to look at independently. I'd be happy to chat about this further in case you decide to go down this road.
Thanks for the comments @aphearin. I'll take a look at your links.
> The level of pain it would require to reimplement the internals of fastspecfit into JAX depends entirely on the nitty-gritty details of the algorithms you're using - sometimes this is rather straightforward, sometimes it's a bit of a lift, and sometimes it requires abandoning certain sophisticated iterative algorithms in favor of simpler non-adaptive ones.
At @dmargala's urging, I tried to write the fitting "guts" using as much pure numpy as I could. For example, here's the objective function that is passed to `scipy.optimize.least_squares`:
https://github.com/desihub/fastspecfit/blob/main/py/fastspecfit/emlines.py#L74-L113
with one notable (but not dominant) bottleneck being the construction of the emission-line model (a sum of Gaussians) for the set of parameters being optimized (ranging from 10 to 50 parameters): https://github.com/desihub/fastspecfit/blob/main/py/fastspecfit/emlines.py#L28-L72
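For readers without the repo handy, the overall pattern looks roughly like the following minimal sketch (a toy two-line fit, not the fastspecfit code; the parameter packing and bounds are made up for illustration):

```python
import numpy as np
from scipy.optimize import least_squares

def gaussian_sum(params, wave):
    # Parameters packed as [amp_1, mu_1, sigma_1, amp_2, mu_2, sigma_2, ...]
    model = np.zeros_like(wave)
    for amp, mu, sigma in params.reshape(-1, 3):
        model += amp * np.exp(-0.5 * ((wave - mu) / sigma) ** 2)
    return model

def objective(params, wave, flux, ivar):
    # Inverse-variance weighted residuals, as least_squares expects.
    return np.sqrt(ivar) * (gaussian_sum(params, wave) - flux)

# Fake spectrum with two well-separated emission lines.
rng = np.random.default_rng(1)
wave = np.linspace(4800.0, 5100.0, 600)
truth = np.array([3.0, 4861.0, 3.0, 9.0, 5007.0, 3.0])
flux = gaussian_sum(truth, wave) + 0.05 * rng.standard_normal(wave.size)
ivar = np.full(wave.size, 1.0 / 0.05**2)

x0 = np.array([1.0, 4860.0, 2.0, 1.0, 5005.0, 2.0])
lo = [0.0, 4850.0, 0.5, 0.0, 4995.0, 0.5]      # non-negative amplitudes,
hi = [50.0, 4875.0, 10.0, 50.0, 5015.0, 10.0]  # bounded centers and widths
fit = least_squares(objective, x0, args=(wave, flux, ivar),
                    method='trf', bounds=(lo, hi))
```

Because `objective` is pure numpy with no Python-level branching on data values, it is also the kind of function that ports naturally to `jax.numpy`.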
For the optimization, I was thinking of using one of the constrained optimization algorithms in https://jaxopt.github.io/stable/constrained.html, although I have no idea how to get JAX (and JAXlib) running at NERSC (whether in a Docker/shifter container or not; see, e.g., https://github.com/google/jax/issues/6340).
OK, this looks pretty tractable to me actually. The usual things that need to be rewritten are control flow within for loops, which takes a little fiddling but is usually not too bad. One blocker to be aware of is while loops, which are a no-go for taking gradients with JAX (while loops are actually supported by JAX, but I think reverse-mode differentiation through them is unsupported and I don't know whether support is coming anytime soon). I think rewriting the _trapz_rebin kernel in JAX might be the first focal point; it doesn't look too bad at first glance.
`trapz_rebin` is also one of the bottlenecks in Redrock, actually (from which I stole this code shamelessly but with permission from @sbailey), but I'm not sure if someone is already rewriting it away from `numba.jit` and into another GPU-optimized implementation.
I'm not seeing anything in the `_trapz_rebin` kernel that looks like a blocker.
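To make that concrete, trapezoidal rebinning can be written without any while loops at all. The following is a hedged sketch (my own formulation, not the Redrock/fastspecfit kernel): insert the new bin edges into the native grid, take an exact cumulative trapezoid, and difference it at the edges. Every step is a vectorized array operation, so the same logic would translate line-for-line to `jax.numpy`:

```python
import numpy as np

def trapz_rebin(x, y, edges):
    # Exact trapezoidal rebinning of a piecewise-linear flux y(x) onto
    # bins defined by `edges` (which must lie within [x[0], x[-1]]).
    # Augment the grid with the bin edges so every segment is linear.
    x_all = np.union1d(x, edges)
    y_all = np.interp(x_all, x, y)
    # Cumulative trapezoid integral, exact on the augmented grid.
    cum = np.concatenate([[0.0],
                          np.cumsum(0.5 * (y_all[1:] + y_all[:-1])
                                    * np.diff(x_all))])
    # The edges are grid points of x_all, so this lookup is exact.
    F = np.interp(edges, x_all, cum)
    return np.diff(F) / np.diff(edges)  # bin-averaged flux
```

For example, rebinning the line `y = 2x + 1` onto the bins `[0, 5]` and `[5, 10]` returns the exact bin means `[6, 16]`.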
@moustakas in case it's helpful, I recently wrote a GPU-optimized version of `trapz_rebin` that has now been merged into the main branch of Redrock. If you are rebinning to many z then it is much faster.
The code refactoring in #92, #95, and #96 added much more robust and extensive fitting capabilities (even after removing the dependence on the very slow `astropy.modeling` classes), at the expense of speed. For example, for the non-linear least-squares fitting used to model the emission lines, I switched from the Levenberg-Marquardt (`lm`) algorithm in `scipy.optimize.least_squares` to the Trust Region Reflective (`trf`) algorithm, which incorporates bounds into the minimization procedure and is significantly more robust---but slower! The bottleneck in `trf` is the numerical differentiation step (see, e.g., https://stackoverflow.com/questions/68507176/faster-scipy-optimizations), so I think the code is a great candidate to be ported to GPUs, where algorithmic differentiation using, e.g., JAX can be factors of many faster.

In addition, for the stellar continuum fitting (including inferring the velocity dispersion), the main algorithm is the non-negative least-squares fitting provided by `scipy.optimize.nnls`. Despite its Fortran bindings, `nnls` is still a bottleneck; the other slow pieces are the resampling of the templates at the redshift of each object and the convolution with the resolution matrix (here, I'm using Redrock's `trapz_rebin` algorithm, which already uses `numba`/`jit` for speed---but still takes non-negligible time).

@dmargala @marcelo-alvarez @craigwarner-ufastro @sbailey and others---I'd be grateful for any thoughts or insight on how to proceed.
In the meantime, here are some profiling results:
Logging into perlmutter:
I get the following profile. (Note: I wasn't sure how to launch `snakeviz` on perlmutter, so I copied the `fastspec.prof` file to my laptop. Also, I'm ignoring the I/O for the moment because those steps are slow for a single object but should be a small fraction of the total time when fitting a full healpixel.)
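As an aside, when `snakeviz` isn't available on the login node, the standard-library `pstats` module can summarize the same `.prof` file in the terminal. A self-contained sketch (the `work` function and file name are stand-ins, not the fastspec call):

```python
import cProfile
import pstats

def work():
    # Stand-in for the fastspec call being profiled.
    return sum(i * i for i in range(100_000))

# Write the profile to disk, as `python -m cProfile -o ...` does.
prof = cProfile.Profile()
prof.runcall(work)
prof.dump_stats('fastspec-toy.prof')

# Text-mode alternative to snakeviz: print the hot spots directly.
stats = pstats.Stats('fastspec-toy.prof')
stats.sort_stats('cumulative').print_stats(10)  # top 10 by cumulative time
```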