brandondube / prysm

physical optics: integrated modeling, phase retrieval, segmented systems, polynomials and fitting, sequential raytracing...
https://prysm.readthedocs.io/en/stable/
MIT License

Feature request: Unit support in prysm #116

Open egemenimre opened 2 days ago

egemenimre commented 2 days ago

Currently, prysm has no support for units libraries such as astropy.units or pint. In particular, pint offers a @wraps decorator that converts Quantity arguments to fixed units before they reach lower-level (and computationally intensive) functions, so the conversion overhead occurs only at the beginning and at the end, while high-level functions always accept Quantity objects.

pint handles units nicely, so you never multiply by 1e3 anywhere: you just call dx.to(u.um) to convert dx, whatever length unit it carries, to microns. It also has numpy support, so most things work out of the box as long as you use numpy functions rather than the built-in math module.

I believe prysm would benefit immensely from this sort of units support. If you need help, I will gladly do some work on it.

As an example:

from pint import UnitRegistry

# Init units
u = UnitRegistry()

# aperture diameter and radius
aperture_diam = 650 * u.mm  # defined with units
aperture_radius = aperture_diam / 2.0
samples = 256  # grid size

# generate the square grid in polar coords
# r should automatically be a Quantity in u.mm and t in u.rad
r, t = _generate_grid(aperture_diam, samples)

# generate aperture (circle mask applied to the square grid)
# accepts the aperture radius as a Quantity object
aperture = circle(aperture_radius, r)
brandondube commented 1 day ago

Thank you for your interest in this. I am happy to review a pull request, but would add the caveat that anything adding units must use the current set of units (which has been the same for ~10 years) when no unit is provided, must not impact performance, and must not break the flexible backend system. Whatever units library is used has to meet those objectives.
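A minimal sketch of one way to meet those constraints, under stated assumptions: strip units once at the public boundary using duck typing, so prysm never imports a units library (leaving the backend system alone) and plain floats, interpreted in prysm's existing default units, cost only one attribute lookup. The helper `as_magnitude` is hypothetical; it relies only on pint's `Quantity.m_as(unit)` method, and `FakeQuantity` is a stand-in so the example runs without pint.

```python
def as_magnitude(value, default_unit):
    """Return value as a bare number/array in default_unit.

    Plain floats and arrays are assumed to already be in prysm's default
    unit and are returned untouched. Anything exposing pint's ``m_as``
    method is converted to default_unit and stripped to its magnitude.
    """
    m_as = getattr(value, "m_as", None)
    if m_as is None:
        return value  # legacy path: no units library involved
    return m_as(default_unit)


class FakeQuantity:
    """Stand-in for pint.Quantity, supporting only mm and um lengths."""

    _to_um = {"mm": 1000.0, "um": 1.0}

    def __init__(self, magnitude, unit):
        self.magnitude = magnitude
        self.unit = unit

    def m_as(self, unit):
        return self.magnitude * self._to_um[self.unit] / self._to_um[unit]


print(as_magnitude(0.6328, "um"))                   # float passes through
print(as_magnitude(FakeQuantity(2.5, "mm"), "um"))  # converted, then stripped
```

Because the check is an attribute lookup rather than an `isinstance` test against pint's class, the same code path works whether or not pint is installed.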

egemenimre commented 1 day ago

OK, thanks for the answer, I'll see what I can do. :)

egemenimre commented 20 hours ago

After a full day of poking and converting, I gave up. I can't make the changes without breaking the API: I have to insert with_units=False in a lot of places, and certain things are @property attributes, where I can't even do that.

I would still like to propose some changes that remove some active blockers for units, mostly because math does not support units and numpy does, so a simple math.sin(alpha) could break things whereas np.sin(alpha) does not.

You can close the issue if you wish.

brandondube commented 20 hours ago

I use stdlib math very little within prysm, mostly where it is clearer than the alternative. There is sometimes a "three color" problem, where prysm.mathops.np is really CuPy, but you are going to do an operation on, for example, a 3x3 array or a scalar, where it is preferable to do it on the CPU. When it is definitely a scalar, I have used math. For small arrays, I do import numpy as truenp, which confuses people who are not familiar with the sleight of hand that mathops.np does.
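The "sleight of hand" can be pictured as a module proxy that forwards attribute access to whichever array library is currently active. The sketch below is a simplified illustration of that idea, not prysm's actual mathops code; it uses stdlib `math` and `cmath` as stand-in backends so that it runs without NumPy or CuPy, and `BackendShim`, `set_backend`, and `sqrt_via_backend` are illustrative names.

```python
import cmath
import math


class BackendShim:
    """Forward attribute lookups to a swappable backend module."""

    def __init__(self, src):
        self._srcmodule = src

    def __getattr__(self, key):
        # Only called when normal lookup fails, i.e. for backend functions.
        return getattr(self._srcmodule, key)


np = BackendShim(math)  # "np" is just the proxy's name, as in mathops


def set_backend(module):
    """Swap the module that every np.<func> call resolves to."""
    np._srcmodule = module


def sqrt_via_backend(x):
    # Library code always calls np.<func>; the active backend decides what runs.
    return np.sqrt(x)


print(sqrt_via_backend(4.0))   # resolved to math.sqrt
set_backend(cmath)
print(sqrt_via_backend(-4.0))  # resolved to cmath.sqrt
```

Code that must always run on the CPU (scalars, 3x3 matrices) would bypass the proxy by importing the real library directly, which is the role `import numpy as truenp` plays.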

Perhaps I am missing something, but is there a reason you couldn't do:

# ...
ret_qty = True
if not isinstance(val, u.Quantity):
    val = val * u.mm # whatever unit prysm uses in this location
    ret_qty = False

# ...
out = 5 * u.mm # as example
if not ret_qty:
    out = out.m_as(u.mm) # unravel it to the underlying data

return out
egemenimre commented 8 hours ago

Let's go with a concrete example. I deleted my changes in a fit of rage, so I'll reconstruct them without testing; there may be minor errors, but I assure you I made them work last time. :)

I started with the first diffraction model notebook.

Quite a lot is supported out of the box, thanks to numpy.

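xi, eta = make_xy_grid(256, diameter=10 * u.mm)
r, t = cart_to_polar(xi, eta)
dx = xi[0, 1] - xi[0, 0]  # sample spacing, carries mm units

A = circle(5 * u.mm, r)

# skip the hopkins etc.

with_units = True  # flag added in my (now deleted) branch
wf = Wavefront.from_amp_and_phase(A, None, HeNe * u.um, dx, with_units=with_units)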

The first challenge is the Wavefront object. I can decorate it with @u.wraps, no probs.

import numpy as np
from pint import UnitRegistry

u = UnitRegistry()


class Wavefront:
    """(Complex) representation of a wavefront."""

    # u.wraps format: (return_type, (self, cmplx_field, wavelength, dx, space, with_units), strict_check)
    # self, dimensionless stuff and booleans are all None. Possible Quantity objects have units like u.um
    # if you send a wavelength in float, it will be assumed to be u.um
    @u.wraps(None, (None, None, u.um, None, None, None), False)
    def __init__(self, cmplx_field, wavelength, dx, space='pupil', with_units=False):
        """Create a new Wavefront instance.

        Parameters
        ----------
        cmplx_field : numpy.ndarray
            complex-valued array with both amplitude and phase error
        wavelength : float
            wavelength of light, microns
        dx : float
            inter-sample spacing, mm (space=pupil) or um (space=psf)
        space : str, {'pupil', 'psf'}
            what sort of space the field occupies
        with_units : bool, optional
            if True, attach pint units to wavelength and dx

        """
        self.data = cmplx_field

        if with_units:

            self.wavelength = wavelength * u.um

            # !!!! There is more to it, see text below
            if space == "pupil":
                self.dx = dx * u.mm
            else:
                self.dx = dx * u.um
        else:
            self.wavelength = wavelength
            self.dx = dx

        self.space = space

    @classmethod
    @u.wraps(None, (None, None, u.nm, u.um, u.mm, None), False)
    def from_amp_and_phase(cls, amplitude, phase, wavelength, dx, with_units=False):
        """Create a Wavefront from amplitude and phase.

        Parameters
        ----------
        amplitude : numpy.ndarray
            array containing the amplitude
        phase : numpy.ndarray, optional
            array containing the optical path error with units of nm
            if None, assumed zero
        wavelength : float
            wavelength of light with units of microns
        dx : float
            sample spacing with units of mm
        with_units : bool, optional
            passed through to the Wavefront constructor

        """
        if phase is not None:
            phase_prefix = 1j * 2 * np.pi / wavelength / 1e3  # / 1e3 does nm-to-um for phase on a scalar
            P = amplitude * np.exp(phase_prefix * phase)
        else:
            P = amplitude
        return cls(P, wavelength, dx, with_units=with_units)

So, the @u.wraps decorator passes plain floats through unchanged, assuming they are already in the declared units, and converts Quantity arguments to the declared units (stripping them to magnitudes) before calling the method. So you can define a wavelength in km if you wish, and it gets converted to um before it reaches the method. This adds some overhead, but leaves the inner workings (usually) untouched.
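To make that conversion rule concrete, here is a toy re-implementation of the non-strict behaviour described above: plain numbers pass through untouched, Quantity-like arguments are converted to the declared unit and stripped, and `None` entries are left alone. It is a pint-free illustration, not pint's code; `ToyQuantity`, `toy_wraps`, and the length-only unit table are assumptions made so that it runs with the standard library. With real pint the decorator line would read `@u.wraps(None, (None, u.um, u.mm), False)`.

```python
import functools
import inspect

# Toy conversion table: scale factor from each length unit to micrometres.
_TO_UM = {"km": 1e9, "m": 1e6, "mm": 1e3, "um": 1.0}


class ToyQuantity:
    """Minimal stand-in for a pint Quantity (lengths only)."""

    def __init__(self, magnitude, unit):
        self.magnitude = magnitude
        self.unit = unit

    def m_as(self, unit):
        return self.magnitude * _TO_UM[self.unit] / _TO_UM[unit]


def toy_wraps(arg_units):
    """Non-strict wraps: convert Quantity-like args, pass floats through."""
    def decorator(func):
        sig = inspect.signature(func)
        # Map parameter names to declared units (None = leave untouched).
        units = dict(zip(sig.parameters, arg_units))

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name, value in list(bound.arguments.items()):
                unit = units.get(name)
                if unit is not None and isinstance(value, ToyQuantity):
                    bound.arguments[name] = value.m_as(unit)
            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


@toy_wraps((None, "um", "mm"))
def describe(field, wavelength, dx):
    # Inside, wavelength is always a bare float in um and dx one in mm.
    return field, wavelength, dx


print(describe("E", 0.6328, 0.01))
print(describe("E", ToyQuantity(0.5, "mm"), ToyQuantity(20, "um")))
```

Matching units to parameters by name (rather than by position in the bound arguments) keeps keyword calls and omitted defaults from shifting the unit assignments.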

The with_units=False default ensures backward compatibility, so all the old code stays blissfully unaware. The downside is that for every single method you have to remember to write with_units=True in the call. And you forget it; I know I did. Maybe it is possible to set a global config parameter instead (burying the with_units checks inside the methods), but then you may practically force the user to use Quantity objects everywhere.

The third issue is that the unit of dx depends on space (mm in the pupil, um in the PSF plane). The decorator forces a single unit, so it can't know which one was meant. So I set the decorator entry for dx to None and accepted whatever the user sent; by checking whether it was a Quantity or a float, and whether space was the pupil, I was able to work around the issue (not shown in the code above; I couldn't be bothered to rewrite it for this text). But I lost the decorator support for dx. Not the end of the world, but it did make me scratch my head.
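A pint-free sketch of that workaround: because a single decorator unit cannot cover both cases, the default unit is looked up from `space` first and only then is the argument converted. The mm/um convention comes from the docstring above; `DX_DEFAULT_UNIT`, `normalize_dx`, and the `(magnitude, unit)` tuple standing in for a Quantity are assumptions for illustration.

```python
# Convention from the Wavefront docstring: dx is mm in pupil space, um in psf space.
DX_DEFAULT_UNIT = {"pupil": "mm", "psf": "um"}
_TO_UM = {"m": 1e6, "mm": 1e3, "um": 1.0}


def normalize_dx(dx, space="pupil"):
    """Return dx as a float in the default unit for `space`.

    dx is either a bare float (assumed to already be in the default unit,
    preserving the current behaviour) or a (magnitude, unit) tuple that
    stands in for a Quantity.
    """
    target = DX_DEFAULT_UNIT[space]
    if isinstance(dx, tuple):
        magnitude, unit = dx
        return magnitude * _TO_UM[unit] / _TO_UM[target]
    return dx


print(normalize_dx(0.04))                    # float, assumed mm
print(normalize_dx((40.0, "um"), "pupil"))   # converted to mm
print(normalize_dx((2.0, "mm"), "psf"))      # converted to um
```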

OK, back to the diffraction model notebook. The next step is to focus:

E = wf.focus(100 * u.mm, with_units=with_units)
psf = E.intensity
fno = 10
psf_radius = 1.22*(HeNe * u.um)*fno
psf.plot2d(xlim=psf_radius*10, power=1/3, cmap='gray', interpolation='bicubic')

The next challenge is wf.focus(). It basically calls two functions: focus (which does the FFT work, no problem) and then pupil_sample_to_psf_sample, which computes the new sample spacing. It is easy enough to handle with @u.wraps and with_units.
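For reference, the spacing change follows the standard Fraunhofer sampling relation: the PSF-plane spacing is wavelength × focal length / (N × pupil spacing). With wavelength in um and focal length and pupil spacing in mm, the mm cancel and the result comes out in um, which is why the mixed defaults work without a units library. The function below is a hypothetical restatement of that relation, not prysm's source; the numbers mirror the notebook (10 mm pupil, 100 mm focal length, HeNe at 0.6328 um, so F/10).

```python
def psf_sample_spacing(pupil_dx_mm, samples, wavelength_um, efl_mm):
    """Hypothetical helper: PSF-plane sample spacing in um.

    dx_psf = wavelength * efl / (samples * pupil_dx). With wavelength in um
    and efl and pupil_dx in mm, the mm cancel and the result is in um.
    """
    return wavelength_um * efl_mm / (samples * pupil_dx_mm)


samples = 256
pupil_dx = 10 / samples  # 10 mm pupil across 256 samples
dx_psf = psf_sample_spacing(pupil_dx, samples, 0.6328, 100)
print(dx_psf)  # approximately 6.328 um, i.e. wavelength * F/#
```

Since `samples * pupil_dx` is the pupil diameter, the spacing reduces to wavelength × F/#, one sample per lambda·F/# in this configuration.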

In the meantime, I also modified the RichData object to accept units, again with @u.wraps and with_units. This is trickier, because RichData does not use mm everywhere: in PSF space it uses microns, and after psf_to_mtf it uses cycles/mm. This sort of thing makes life with units a tad more difficult. I found ways around it without modifying the library, but it would have been much easier if I could create a RichData object with explicitly specified units, for example.
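The PSF to MTF step shows why an explicit unit would help: the FFT frequency spacing is 1/(N·dx), which is cycles/um when dx is in um, and a factor of 1000 turns it into the cycles/mm mentioned above. A hypothetical sketch of carrying the unit alongside the spacing (`SampledAxis` and `frequency_axis` are illustrative names, not prysm's RichData API):

```python
from dataclasses import dataclass

# Scale from "cycles per <length unit>" to cycles/mm.
_CY_PER_UNIT_TO_CY_PER_MM = {"um": 1e3, "mm": 1.0}


@dataclass
class SampledAxis:
    """Hypothetical: number of samples, their spacing, and its unit."""
    samples: int
    spacing: float
    unit: str


def frequency_axis(axis):
    """FFT frequency spacing 1/(N * dx), expressed in cycles/mm."""
    df = 1.0 / (axis.samples * axis.spacing)  # cycles per axis.unit
    scale = _CY_PER_UNIT_TO_CY_PER_MM[axis.unit]
    return SampledAxis(axis.samples, df * scale, "cy/mm")


psf_axis = SampledAxis(samples=250, spacing=4.0, unit="um")
print(frequency_axis(psf_axis))
```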

The final straw for me was Wavefront.intensity. It is a property, so it accepts no flags, yet it creates a RichData object, for which I need to be able to set the with_units flag. This is where I threw in the towel, because I can't turn a property into a method call without breaking the code for everyone else.

egemenimre commented 8 hours ago

Later, once my head had cooled, I realised that a global config setting at the very beginning would solve all of these issues and keep everything under the hood. For this we may need to work together in a units_support branch. The entire undertaking would be a huge PR, assuming you see some value in all this, now that you understand how pint works with @u.wraps.
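A minimal sketch of that global switch, under assumptions: a module-level config object whose flag defaults to False, so existing code is unaffected, and which both regular functions and `@property` accessors consult instead of a `with_units` argument. `Config`, `attach_units`, and `ToyWavefront` are hypothetical names, and a `(value, unit)` tuple stands in for a pint Quantity so the example runs without pint.

```python
class Config:
    """Hypothetical global settings object."""

    def __init__(self):
        self.with_units = False  # default keeps today's float-only behaviour


config = Config()


def attach_units(value, unit):
    # Stand-in for `value * u.<unit>`; a real branch would use pint here.
    if config.with_units:
        return (value, unit)
    return value


class ToyWavefront:
    def __init__(self, data, dx, space="pupil"):
        self.data = data
        self.dx = dx
        self.space = space

    @property
    def intensity(self):
        # A property cannot take a flag, but it can read the global one.
        values = [abs(v) ** 2 for v in self.data]
        unit = "mm" if self.space == "pupil" else "um"
        return values, attach_units(self.dx, unit)


wf = ToyWavefront([1j, 2], dx=0.04)
print(wf.intensity)  # ([1.0, 4], 0.04): legacy floats
config.with_units = True
print(wf.intensity)  # ([1.0, 4], (0.04, 'mm')): units attached
```

The trade-off raised earlier still applies: once the switch is on, every caller receives Quantity-like results, so it suits a whole-session choice rather than a per-call one.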

egemenimre commented 8 hours ago

Another remark: plot2d's xlim doesn't work when it is used like this with units:

psf.plot2d(xlim=psf_radius*10, power=1/3, cmap='gray', interpolation='bicubic')

The limits are then completely wrong; it looks like I see the entire array.

This is what I do elsewhere:

# image limits (with units)
xlim = airy_radius * 2

# note how I define units for the axis labels
psf_mono.plot2d(xlim=(-xlim, xlim), log=False, axis_labels=(xlim.u, xlim.u))

Furthermore, the axis labels here should show units by default (thanks to pint's matplotlib support), but prysm explicitly sends (None, None) when no labels are given, which removes the units.
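A sketch of how a plotting helper could normalise limits before handing them to matplotlib: accept a scalar (meaning symmetric limits) or a pair, and strip Quantity-like values via their `magnitude` attribute. `normalize_lim` and `FakeQty` are hypothetical; this does not describe how `plot2d` currently parses `xlim`, and because it drops units without converting them, the caller still has to express the limits in the same unit as the data's axes (with pint, e.g. `xlim.m_as(u.um)`).

```python
def _strip(value):
    # pint Quantities expose .magnitude; plain numbers and tuples do not.
    return getattr(value, "magnitude", value)


def normalize_lim(lim):
    """Return (low, high) floats from a scalar or a 2-element pair."""
    if lim is None:
        return None
    lim = _strip(lim)  # a Quantity wrapping a pair becomes a plain pair
    try:
        low, high = lim
    except TypeError:  # scalar: symmetric limits about zero
        return (-float(lim), float(lim))
    return (float(_strip(low)), float(_strip(high)))


class FakeQty:
    """Stand-in for a scalar pint Quantity."""

    def __init__(self, magnitude):
        self.magnitude = magnitude


print(normalize_lim(15.0))
print(normalize_lim((FakeQty(-7.7), FakeQty(7.7))))
```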

egemenimre commented 8 hours ago

Side note regarding Quantity: if you define a parameter with units, say dx = 5 * u.mm, you can't add a float to it any more. Which makes sense, because what is dx + 5?

On the other hand, division and multiplication work: dx / 5 = 1 * u.mm and dx / (5 * u.mm) = 1 * u.dimensionless.

Unit handling may need some care, because dx / (5 * u.um) is not 1000 * u.dimensionless but 1 * u.mm / u.um. It is not wrong, but not what you expect either. Usually it does not matter, but sometimes you have to call .to_reduced_units() on the result, e.g. (dx / (5 * u.um)).to_reduced_units(), to force everything into a single type (here it would be u.dimensionless).