LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License

Basis in the long term. #131

Closed Jordan-Dennis closed 1 year ago

Jordan-Dennis commented 1 year ago

Hi all. There has been a considerable amount of discussion recently about the Basis. This issue is going to be very long, so don't try to read it all at once.

Unless certain assumptions are made about the aperture (often, but not always, true), the basis must be re-evaluated at runtime. The assumptions I am referring to are those of the CircularAperture, which is the only aperture currently on main: it assumes that the edges of the aperture touch the edges of the wavefront perfectly. The same assumption is made by the zernike_basis on main, which is a direct port of the POPPY implementation (and cannot be jit compiled). This reduces the amount of information we need to know about the system in order to generate the basis to just the number of pixels.

However, I think (this is an opinion that is not particularly well informed) that we need a more general routine for generating the zernike basis, one that can work on components that do not fit so neatly over the wavefront. I have spent a large amount of time working on this and have produced working code with a fair amount of help from @ataras2. The problem is that the basis is recomputed on every pass, which, depending on the resolution and the number of terms, can take ~100ms. Unless we are learning large structural changes in the aperture, as I have done with the HST, I do not think this re-computation is necessary. @benjaminpope agrees and is actually much more gung-ho, going so far as to say that we don't need to recompute in (almost) every use case.

Regardless, the code @ataras2 and I have produced can be jit compiled, making it GPU friendly, which should hopefully reduce the run-time since it is hand-vectorised. I have also attempted to cache the basis vectors so that they are calculated on the first pass and then never again. Unfortunately jax/equinox does not permit the use of the @cached_property decorator from functools, so the simple option cannot be used. My next attempt tried to get around the immutability enforced by equinox by accessing the fields via the .__dict__ attribute. While this works, it cannot be done from within a traced function, making it useless in this scenario.
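The attribute-caching problem above can be seen even without equinox: an attribute written from inside a jit-traced function stores a tracer rather than a concrete array, so the "cache" is unusable afterwards. A minimal sketch (class name hypothetical, not dLux code):

```python
import jax
import jax.numpy as np

class CachingLayer:
    """Hypothetical layer that tries to cache its basis on first call."""

    def __call__(self, x):
        if not hasattr(self, "_basis"):
            # Under jit this branch runs at trace time, so the attribute
            # is assigned a tracer, not a concrete array.
            self._basis = np.arange(3.0) * x
        return self._basis

layer = CachingLayer()
out = jax.jit(layer)(2.0)  # the first call works...
# ...but layer._basis is now a leaked tracer and errors if used again.
```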

At this point we had a meeting and discussed just having a hard-coded array of a finite number of the Zernike terms. This led to the creation of a table of hard-coded Zernike lambda functions. I am in the process of generalising this to write a function that takes a noll index and returns a jax.tree_util.Partial representing the corresponding Zernike. These "functional Zernikes" have the advantage of faster computation because the coefficients do not need to be calculated ahead of time. According to some early profiling by @ataras2 and me, this resulted in a roughly tenfold speed-up running on a GPU.
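As a sketch of the "functional Zernike" idea, assuming Noll ordering for the first few terms (the table contents and function names are illustrative, not the actual dLux implementation):

```python
import jax.numpy as np
from jax.tree_util import Partial

# Illustrative table of hard-coded Zernike lambdas in Noll ordering,
# evaluated directly on polar coordinates (rho, theta).
ZERNIKES = {
    1: lambda rho, theta: np.ones_like(rho),                      # piston
    2: lambda rho, theta: 2.0 * rho * np.cos(theta),              # x tilt
    3: lambda rho, theta: 2.0 * rho * np.sin(theta),              # y tilt
    4: lambda rho, theta: np.sqrt(3.0) * (2.0 * rho ** 2 - 1.0),  # defocus
}

def zernike(noll_index: int) -> Partial:
    """Look up a hard-coded Zernike and wrap it as a Partial, so the
    returned function is a valid pytree leaf and safe to pass through
    jit-compiled code."""
    return Partial(ZERNIKES[noll_index])
```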

The orthonormalisation code does no harm and I would like to keep it on the off chance that it is useful. I think it has the potential to improve the accuracy, and if only the orthonormalisation is happening, without generating the Zernikes from scratch each time, it is not actually too slow, taking ~1ms to complete even for reasonably high resolutions. This was again profiled on @ataras2's GPU, and the performance is considerably worse on CPU.
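For context, the orthonormalisation step amounts to a Gram-Schmidt pass over the basis stack, weighted by the aperture. A minimal sketch (assumed shapes and function name, not the actual dLux routine):

```python
import jax.numpy as np

def orthonormalise(basis, aperture):
    """Gram-Schmidt orthonormalise a stack of basis terms (n, h, w)
    over the pixels of an arbitrary aperture, using the
    aperture-weighted mean as the inner product."""
    area = aperture.sum()
    out = []
    for vec in basis:
        for prev in out:
            # Subtract the projection onto each earlier basis term.
            vec = vec - (vec * prev * aperture).sum() / area * prev
        # Normalise to unit RMS over the aperture.
        out.append(vec / np.sqrt((vec ** 2 * aperture).sum() / area))
    return np.stack(out)

# Tiny demonstration on a full (all-ones) aperture.
x, y = np.meshgrid(np.linspace(-1.0, 1.0, 8), np.linspace(-1.0, 1.0, 8))
aperture = np.ones_like(x)
ortho = orthonormalise(np.stack([np.ones_like(x), x]), aperture)
```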

So, now that you know the history and state of things at the moment, let us discuss what remains to be clarified.

benjaminpope commented 1 year ago

The mirrors never shear significantly in JWST, so we don't need to worry about that. Hexikes are just coordinate-remapped Zernikes, so you can use your existing table of Zernike polynomials and have a layer that pre-warps the coordinates.
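A sketch of that coordinate pre-warp idea (the function name and the (2, npix, npix) coordinate convention are assumptions):

```python
import jax.numpy as np

def pre_warp(coords, matrix):
    """Apply a 2x2 linear remapping to a (2, npix, npix) stack of
    (x, y) coordinates, so an existing Zernike table can then be
    evaluated on the warped grid."""
    return np.einsum("ij,jxy->ixy", matrix, coords)

# The identity warp leaves the grid unchanged; a shear or rotation
# matrix here would produce the remapped coordinates for
# hexike-style evaluation.
x, y = np.meshgrid(np.linspace(-1.0, 1.0, 4), np.linspace(-1.0, 1.0, 4))
warped = pre_warp(np.stack([x, y]), np.eye(2))
```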

I don't understand the 100ms. Base numpy can evaluate a polynomial on a 1024x1024 meshgrid in ~1ms. It should be faster with jax and tree_map.

benjaminpope commented 1 year ago

I suppose we could get up to 100ms if we evaluate this many times to make a very large array of zernikes?

LouisDesdoigts commented 1 year ago

So all of this functionality can just be put into an appropriate file in the utils module, and I imagine most of this will want to live there anyway.

There are no real-world problems we will encounter with aperture shapes changing anywhere near enough to warrant a re-orthonormalisation.

I think we want two ways of interacting with these classes: a pre-calculated zernike basis and a run-time zernike basis. Hard-coding the polynomial functions is fine; they can be loaded into a list or dictionary, and jax.tree_map can be used to perform the operation using vmap.
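A minimal sketch of that pattern (the two polynomial entries are illustrative): plain functions are pytree leaves, so a list of hard-coded polynomials can be evaluated over one grid with a single tree_map.

```python
import jax
import jax.numpy as np

# Illustrative hard-coded polynomial functions.
polys = [
    lambda rho, theta: np.ones_like(rho),          # piston
    lambda rho, theta: 2.0 * rho * np.cos(theta),  # x tilt
]

x, y = np.meshgrid(np.linspace(-1.0, 1.0, 64), np.linspace(-1.0, 1.0, 64))
rho, theta = np.hypot(x, y), np.arctan2(y, x)

# Each function is a leaf, so tree_map evaluates all of them on the grid.
basis = np.stack(jax.tree_util.tree_map(lambda f: f(rho, theta), polys))
```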

LouisDesdoigts commented 1 year ago

I suspect we want these different functionalities to exist as different classes to prevent classes being clogged up with extra parameters

Jordan-Dennis commented 1 year ago

@benjaminpope the ~100ms was a typo on my part. In the slack you will see @ataras2 and I timed it for the square aperture and it was ~10ms (7ms).

benjaminpope commented 1 year ago

Oh good. That is pretty reasonable. Anything is fine for a dynamic feature of dLux so long as it takes ~ 10x less time than a typical Fourier transform.

Jordan-Dennis commented 1 year ago

I think the structure that I might implement will be something like.

class HardCodedLowOrderZernikeBasis(dl.OpticalLayer):  # Very fast
    def __init__(self, noll_indices: list, coeffs: list): ...
    def __call__(self, params: dict) -> dict: ...

class ArbitraryZernikeBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list): ...
    def __call__(self, params: dict) -> dict: ...

class HardCodedLowOrderHexikeBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list): ...
    def __call__(self, params: dict) -> dict: ...

class ArbitraryOrthonormalBasisAndAperture(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: Aperture): ...
    def __call__(self, params: dict) -> dict: ...

class ZernikeBasisAndCircularAperture(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: Aperture): ...
    def __call__(self, params: dict) -> dict: ...

This is still just for single apertures. I haven't decided how I will cover compound apertures yet.

Edit: the AndAperture suffix means that both the Aperture and the Basis are applied in the same layer. This should be slightly more efficient, because the aperture array then only needs to be generated once but is used twice.
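A sketch of what the AndAperture combination buys (function signature and names hypothetical): the aperture array is produced once, then reused for both the transmission mask and the basis OPD.

```python
import jax.numpy as np

def apply_basis_and_aperture(wavefront, aperture, basis, coeffs):
    """Combine the basis OPD and the aperture mask in one layer so the
    aperture array is generated once and used twice."""
    opd = np.einsum("i,ixy->xy", coeffs, basis) * aperture  # first use
    return wavefront * aperture, opd                        # second use

# Trivial demonstration with a single constant basis term.
aperture = np.ones((4, 4))
basis = np.stack([np.ones((4, 4))])
masked, opd = apply_basis_and_aperture(
    np.ones((4, 4)), aperture, basis, np.array([2.0]))
```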

Does this structure look any good?

Jordan-Dennis commented 1 year ago

For the case of JWST and Heimdallar (spelled how it sounds) we need to concern ourselves with a CompoundAperture with Basis terms generated for each Aperture. While this can be done by specifying them all individually, my old code was considerably nicer in this regard. It seems almost worth having:

class CompoundCircularApertureAndBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: CompAperture): ...
    def __call__(self, params: dict) -> dict: ...

class CompoundHexagonalApertureAndBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: CompAperture): ...
    def __call__(self, params: dict) -> dict: ...

class CompoundArbitraryApertureAndBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: CompAperture): ...
    def __call__(self, params: dict) -> dict: ...

But I am not so sure this is a good set-up. It should be fine in the short term, I think. The main concern is that there is no way to get the Basis independently of the aperture. This may not even be a bad thing.

Jordan-Dennis commented 1 year ago

Having done some work, I think it is best if the Zernike basis is always associated with an aperture. It simplifies things considerably.

LouisDesdoigts commented 1 year ago

So we are going to want to do this with inheritance. We should be able to have base classes for:

  1. Static pre-calculated apertures
  2. Static pre-calculated and stored basis (for aperture shapes inherited by the aperture)
  3. Dynamically generated aperture shapes
  4. Dynamically generated basis on the aperture shape

The final classes should exist primarily as lightweight wrappers that inherit from the combination of these.
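A rough sketch of that mixin layout (all class names and internals here are hypothetical; real dLux layers would subclass dl.OpticalLayer):

```python
import jax.numpy as np

class StaticCircularAperture:
    """Base behaviour 1: the aperture array is pre-calculated at init."""
    def __init__(self, npix: int):
        x, y = np.meshgrid(*2 * (np.linspace(-1.0, 1.0, npix),))
        self.aperture = (np.hypot(x, y) <= 1.0).astype(float)

class DynamicBasis:
    """Base behaviour 4: the basis is re-evaluated at call time."""
    def basis(self, coords):
        x, _ = coords
        return np.stack([np.ones_like(x), 2.0 * x])

class StaticApertureDynamicBasis(StaticCircularAperture, DynamicBasis):
    """Final class: a lightweight wrapper over the combination, while
    each base class remains usable on its own."""
```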

LouisDesdoigts commented 1 year ago

They should also be able to be instantiated by themselves though!

Jordan-Dennis commented 1 year ago

So it is difficult to pre-calculate the apertures. Without access to the coordinate system stored in the wavefront, there is really no way of generating an arbitrary aperture at initialisation. @ataras2 had some interesting thoughts about lifting the coordinate information into dl.Telescope/dl.Optics. His idea was that dl.Optics store an instance of dl.Wavefront, and that dl.CreateWavefront be demoted from a layer to a function called in the __init__ method of dl.Optics. The coordinates would then be resolved, the list of layers sorted, and the apertures statically generated. I think it sounds interesting, but it is likely a lot of work.

I know that there are cases where apertures can be pre-generated. However, in these scenarios one has to make very stringent assumptions that are not always true. For example, the dl.CircularAperture that is currently on main assumes that it sits directly in the centre of the wavefront and spans the entire wavefront. While this works for dl.AngularMFT and dl.PhysicalMFT type propagators and is a common occurrence (e.g. simple TOLIMAN and HST models), it very quickly fails elsewhere (e.g. complex HST and Heimdellar).

@LouisDesdoigts what are your thoughts on @ataras2's idea?

LouisDesdoigts commented 1 year ago

So just have pixel scale and npix parameters that are fed to the layer at init time to do the pre-calculation, but don't store them. These can just be the same values that are fed into CreateWavefront.

As for moving the CreateWavefront layer into Optics, it's definitely possible, but I don't think it's what we want in the long run, as it splits the parameters that define the optics across two places (some in layers and some in the optics). We already have access to the pixel scale and npix params when initialising the layers, so no assumptions need to be made!
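A sketch of that init-time pre-calculation (class name hypothetical): npix and pixel_scale are consumed at init to build the transmission array and are never stored on the layer.

```python
import jax.numpy as np

class PreCalcCircularAperture:
    """Hypothetical layer: npix and pixel_scale are used only at init
    to pre-calculate the transmission array, and are not stored."""
    def __init__(self, npix: int, pixel_scale: float, radius: float):
        # Pixel-centred coordinates spanning the wavefront.
        pix = (np.arange(npix) - (npix - 1) / 2.0) * pixel_scale
        x, y = np.meshgrid(pix, pix)
        self.transmission = (np.hypot(x, y) <= radius).astype(float)

    def __call__(self, wavefront):
        return wavefront * self.transmission

layer = PreCalcCircularAperture(npix=5, pixel_scale=0.5, radius=1.0)
```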

Jordan-Dennis commented 1 year ago

Yeah, OK. I guess that is easiest and requires the least work to implement. @ataras2 and I just posted a large UML diagram in the Slack detailing our plan for this stuff. We'll need to revise it now, but it should still be generally true.