spacetelescope / webbpsf

James Webb Space Telescope PSF simulation tool
https://webbpsf.readthedocs.io
BSD 3-Clause "New" or "Revised" License

WebbPSF's distortion model is really slow to compute - we should improve it #426

Closed mperrin closed 3 years ago

mperrin commented 3 years ago

It's always been the case since we added the distortion model that it's a slow part of the calculation, but I've only just recently realized how slow. For typical use cases it's 75-95% of the computation time, totally dominating over the actual optical propagation.

I find myself often disabling the distortion model (add_distortion=False) in most of my own calculations because it's so painfully slow... I'd like to improve on this.

Timing results first, then ideas on what to improve.


**_Timing results:_**

Some detailed timing results using `line_profiler` in Jupyter:

```
%load_ext line_profiler
nrc = webbpsf.NIRCam()
%lprun -f webbpsf.webbpsf_core.SpaceTelescopeInstrument.calc_psf psf = nrc.calc_psf(nlambda=1)
```

That will output a detailed printout of the runtime for `calc_psf`, including the fraction of time spent in each line of that function. The relevant one to look for is `self._calc_psf_format_output(result, local_options)`; the distortion model is called within that to compute the images for the additional output extensions.

In the above case of a monochromatic calculation, that line takes **95% of the time**: 6.7 seconds out of a total runtime of 7 seconds. The actual optical propagation is just 0.3 s in comparison!

Compare to without the distortion model:

```
%lprun -f webbpsf.webbpsf_core.SpaceTelescopeInstrument.calc_psf psf = nrc.calc_psf(nlambda=1, add_distortion=False)
```

Formatting the output now takes only 32% (0.17 s) of a total runtime of 0.5 s (in this case it's just binning down to the detector pixel scale, without applying the distortion model).

Similarly for a broadband calculation with 20 wavelengths:

```
%lprun -f webbpsf.webbpsf_core.SpaceTelescopeInstrument.calc_psf psf = nrc.calc_psf()
```

The distortion model again takes 6.8 s out of a total runtime of 9.3 s.

Using the call profiler `%prun` confirms that it's the `ndgriddata` function call that's taking up nearly all of that time. (See e.g. the output of `%prun psf = nrc.calc_psf(nlambda=1)`.)
**_What might we do to improve this:_**

The current distortion model is almost certainly overkill: we use the SIAF distortion polynomials to work out the coordinate transforms on a pixel-by-pixel basis from the undistorted 'ideal' frame to the detector-coordinates 'science' frame, then use `scipy.interpolate.griddata` to warp the PSF into the distorted frame.

Doing this on a pixel-by-pixel basis is extremely general but likely not necessary for the modest size of PSF postage stamps. I expect we could get very, very close to the same results using a simple linear transformation (e.g. a rotation + skew matrix), which can be implemented in a much more performant way using, for instance, scipy's `affine_transform` function.

Given the relatively small geometric distortion in the JWST instruments, I expect this linear approximation will be adequate for our purposes, and it would remove what is currently the most substantial bottleneck in JWST PSF calculations. I don't think it makes sense to continue spending 70+% of calculation runtime on what is a relatively minor effect for most use cases of PSF modeling (though it is still important enough that we do want to include it at some level).

@shanosborne @grbrady @obi-wan76 @Skyhawk172 I'm curious to hear if you have any thoughts on the above?
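To make the suggestion concrete, here is a minimal sketch of the affine approach, assuming we already have the 2x2 local Jacobian of the 'idl'-to-'sci' transform at the stamp center. The function name and `jacobian` input are hypothetical, not the actual WebbPSF API; in practice the matrix would be derived from the SIAF polynomials.

```python
import numpy as np
from scipy.ndimage import affine_transform

def apply_linear_distortion(psf, jacobian, order=3):
    """Warp a PSF stamp with a locally-linear distortion model.

    `jacobian` is the 2x2 matrix d(sci)/d(idl) evaluated at the stamp
    center (hypothetical input; it would come from the SIAF polynomials).
    affine_transform maps *output* coordinates to *input* coordinates,
    so we pass the inverse of the forward transform, with an offset
    chosen so the warp pivots about the stamp center.
    """
    ny, nx = psf.shape
    center = np.array([(ny - 1) / 2, (nx - 1) / 2])
    inv = np.linalg.inv(jacobian)
    offset = center - inv @ center
    return affine_transform(psf, inv, offset=offset, order=order)
```

This replaces the scattered-point interpolation with a single spline-based resampling, which is where the expected speedup would come from.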
grbrady commented 3 years ago

Implementing a generalized distortion model using griddata is typically very slow, so this is unsurprising. To my understanding, the Delaunay triangulation at the heart of griddata is not particularly parallelizable either.

Primary distortion is proportional to the cube of the field angle, and then there are smaller higher-order terms that most real systems have. To implement a single distortion model that captures the entire field, it would be necessary to use something like griddata. Over smaller subfields, however, I think a linear model of distortion could plausibly be sufficient, depending on the fidelity needed for the task at hand. The affine transform approach should capture this well. One could also consider breaking the image up into subfields, applying linear models locally on each with the needed fidelity, and then assembling a larger image that captures some of the higher-order dependence. This approach is highly parallelizable.
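The tile-by-tile idea could be sketched roughly as follows; `jacobian_at` is a hypothetical callback standing in for evaluating the SIAF polynomials at each tile center, and each tile is warped independently, which is what makes the scheme parallelizable.

```python
import numpy as np
from scipy.ndimage import affine_transform

def warp_by_tiles(image, jacobian_at, tile=64):
    """Warp `image` tile-by-tile with a locally-linear distortion model.

    `jacobian_at(y, x)` (hypothetical) returns the 2x2 local distortion
    Jacobian at a tile center. Each tile is warped about its own center,
    so higher-order field dependence is captured piecewise-linearly.
    """
    out = np.zeros_like(image)
    ny, nx = image.shape
    for y0 in range(0, ny, tile):
        for x0 in range(0, nx, tile):
            sub = image[y0:y0 + tile, x0:x0 + tile]
            jac = jacobian_at(y0 + sub.shape[0] / 2, x0 + sub.shape[1] / 2)
            inv = np.linalg.inv(jac)
            c = (np.array(sub.shape) - 1) / 2
            out[y0:y0 + tile, x0:x0 + tile] = affine_transform(
                sub, inv, offset=c - inv @ c, order=1)
    return out
```

A real implementation would also need to handle the discontinuities at tile boundaries (e.g. by overlapping tiles and blending), which this sketch ignores.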

An alternate approach, which also probably isn't very efficient, is to compute PSFs over subfields and include the (field-varying) linear phase term, which is how distortion manifests itself in the pupil. Then the individual PSF subfields can be assembled into a wider-field image, with the local shift of each giving the distortion. Using the matrix DFT, these subfields can be arbitrarily small, but the field-dependent OPD needs to be calculated individually for each subfield. I'm pretty sure this isn't the way we should go (probably super slow), but I thought I'd mention it in case it gives anyone any ideas. In other words, maybe there's something we can do in the pupil plane instead of the image plane.

Cheers, Greg


JarronL commented 3 years ago

Right now, the add_distortion function uses griddata to transform from an irregularly gridded 'sci' coordinate array to a regularly gridded 'sci' coordinate array. Instead, it would be much faster to use RegularGridInterpolator in place of griddata. To do this, we could start from the already-regular 'idl' coordinates and interpolate onto the set of irregular 'idl' points that map to our desired 'sci' coordinate values.

Some basic code (assuming `psf`, `nx`, `ny`, `pixelscale`, `oversamp`, and a pysiaf aperture `aper` are already defined):

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

# Build 'idl' coordinate axes for the oversampled PSF array
xlin = np.linspace(-1 * (nx - 1.) / 2, (nx - 1.) / 2, nx)
ylin = np.linspace(-1 * (ny - 1.) / 2, (ny - 1.) / 2, ny)
xidl = xlin * pixelscale
yidl = ylin * pixelscale

# Create regular grid interpolator function in terms of 'idl' coordinates
func = RegularGridInterpolator((yidl, xidl), psf, bounds_error=False, fill_value=None)

# Create an x,y grid of 'sci' coords
xarr, yarr = np.meshgrid(xlin, ylin)
xsci_ref, ysci_ref = aper.reference_point('sci')
# Divide by oversample since 'sci' coords are in ~detector pixel sizes
xnew_sci = xarr / oversamp + xsci_ref
ynew_sci = yarr / oversamp + ysci_ref

# Grab the 'idl' coordinates that map to the 'sci' coords we care about
xnew_idl, ynew_idl = aper.sci_to_idl(xnew_sci, ynew_sci)

# Combine all points into a single (npoints, 2) array
pts = np.array([ynew_idl.flatten(), xnew_idl.flatten()]).transpose()

# Evaluate the interpolator at all requested points and reshape to image size
psf_new = func(pts).reshape([ny, nx])
```

I ran a few quick tests on a single 256x256 PSF. This code snippet takes about 10 msec, compared to 800 msec for the griddata call. Image results were very similar, but I haven't yet tested for large field angles.
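For anyone who wants to reproduce the scale of that speedup on synthetic data, here is a toy benchmark along the same lines. The cubic term is just a stand-in for the SIAF polynomials; note the two calls implement forward vs. inverse warps of the same small distortion, so they agree only approximately, but the timing comparison is the point.

```python
import time
import numpy as np
from scipy.interpolate import RegularGridInterpolator, griddata

n = 64
y, x = np.mgrid[0:n, 0:n].astype(float)
psf = np.exp(-((x - n / 2) ** 2 + (y - n / 2) ** 2) / 20.0)

# Toy cubic distortion (standing in for the SIAF polynomials)
xd = x + 1e-5 * (x - n / 2) ** 3
yd = y + 1e-5 * (y - n / 2) ** 3

# griddata: scatter the distorted points, triangulate, interpolate back
t0 = time.perf_counter()
slow = griddata((yd.ravel(), xd.ravel()), psf.ravel(), (y, x), method='linear')
t_grid = time.perf_counter() - t0

# RegularGridInterpolator: evaluate the regular-grid PSF at the mapped points
t0 = time.perf_counter()
func = RegularGridInterpolator((np.arange(n), np.arange(n)), psf,
                               bounds_error=False, fill_value=None)
fast = func(np.column_stack([yd.ravel(), xd.ravel()])).reshape(n, n)
t_reg = time.perf_counter() - t0

print(f"griddata: {t_grid:.4f} s, RegularGridInterpolator: {t_reg:.4f} s")
```

The gap grows with array size, since griddata pays for a Delaunay triangulation while the regular-grid interpolator only does index lookups.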

I will try to integrate this into WebbPSF tomorrow and run a few different PSF calculations to compare final distortion results and timing benchmarks.

JarronL commented 3 years ago

As @mperrin mentioned, we could also use scipy's affine_transform function. If anyone knows how to convert pysiaf's polynomial coefficients into an affine transformation matrix, I could quickly write up a similar function and compare the results.
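One way to get such a matrix without manipulating the polynomial coefficients directly would be to finite-difference the coordinate transform around the PSF location. A sketch, where `transform` is a placeholder for something like `aper.idl_to_sci` (this is not an existing pysiaf or WebbPSF helper):

```python
import numpy as np

def local_affine(transform, x0, y0, eps=1e-3):
    """Approximate a 2D coordinate transform by an affine map near (x0, y0).

    `transform(x, y) -> (u, v)` stands in for e.g. pysiaf's
    `aper.idl_to_sci`. The Jacobian is estimated by central differences,
    giving the linear part of the distortion at the PSF location, plus
    the offset that makes the affine map agree exactly at (x0, y0).
    """
    u0, v0 = transform(x0, y0)
    dux = (transform(x0 + eps, y0)[0] - transform(x0 - eps, y0)[0]) / (2 * eps)
    dvx = (transform(x0 + eps, y0)[1] - transform(x0 - eps, y0)[1]) / (2 * eps)
    duy = (transform(x0, y0 + eps)[0] - transform(x0, y0 - eps)[0]) / (2 * eps)
    dvy = (transform(x0, y0 + eps)[1] - transform(x0, y0 - eps)[1]) / (2 * eps)
    A = np.array([[dux, duy], [dvx, dvy]])   # Jacobian d(u,v)/d(x,y)
    offset = np.array([u0, v0]) - A @ np.array([x0, y0])
    return A, offset
```

The resulting 2x2 matrix (after inversion) is exactly what `scipy.ndimage.affine_transform` expects, so this could feed directly into the affine approach discussed above.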