The mirrors never shear significantly in JWST, so we don't need to worry about that. Hexikes are just coordinate-remapped Zernikes, so you can use your existing table of Zernike polynomials and have a layer that pre-warps the coordinates.
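For concreteness, a minimal sketch of that pre-warp, assuming a regular hexagon with unit circumradius (the helper names here are illustrative, not dLux code):

```python
import jax.numpy as jnp

def hex_edge_radius(theta):
    # Distance from the centre to the edge of a regular hexagon with
    # unit circumradius, as a function of polar angle.
    return (jnp.sqrt(3.0) / 2.0) / jnp.cos(jnp.mod(theta, jnp.pi / 3) - jnp.pi / 6)

def warp_hexagon_to_disc(x, y):
    # Stretch the hexagon onto the unit disc; evaluating the existing
    # Zernike table on the warped coordinates gives the remapped basis.
    r, theta = jnp.hypot(x, y), jnp.arctan2(y, x)
    r_warped = r / hex_edge_radius(theta)  # hexagon edge maps to r = 1
    return r_warped * jnp.cos(theta), r_warped * jnp.sin(theta)
```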
I don't understand the 100ms. Base numpy can evaluate a polynomial on a 1024×1024 meshgrid in ~1ms. It should be faster with jax and tree_map.
I suppose we could get up to 100ms if we evaluate many of these terms to build a very large array of Zernikes?
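For reference, a sketch of the kind of single-term evaluation being timed here (grid size from the comment above; the defocus form is the Noll Z4 term, and timings will vary by hardware):

```python
import jax
import jax.numpy as jnp

# One low-order term on a 1024 x 1024 grid.
x = jnp.linspace(-1.0, 1.0, 1024)
X, Y = jnp.meshgrid(x, x)
R = jnp.hypot(X, Y)

@jax.jit
def defocus(r):
    # Noll Z4 Zernike (defocus)
    return jnp.sqrt(3.0) * (2.0 * r ** 2 - 1.0)

_ = defocus(R).block_until_ready()  # first call compiles; time later calls
```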
So all of this functionality can just be put into an appropriate file in the utils module, and I imagine most of this will want to live there anyway.
There are no real-world problems we will encounter that will have aperture shapes changing anywhere near enough to warrant a re-orthonormalisation.
I think we want two ways of interacting with these classes: a pre-calculated Zernike basis and a run-time Zernike basis. Hard-coding the polynomial functions is fine; they can be loaded into a list or dictionary and evaluated with `jax.tree_map` and `vmap`.
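A minimal sketch of that pattern (the table below is hypothetical, with only a few Noll terms shown):

```python
import jax
import jax.numpy as jnp

# Hypothetical hard-coded table keyed by Noll index.
zernikes = {
    1: lambda r, t: jnp.ones_like(r),                      # piston
    2: lambda r, t: 2.0 * r * jnp.cos(t),                  # tip
    3: lambda r, t: 2.0 * r * jnp.sin(t),                  # tilt
    4: lambda r, t: jnp.sqrt(3.0) * (2.0 * r ** 2 - 1.0),  # defocus
}

x = jnp.linspace(-1.0, 1.0, 256)
X, Y = jnp.meshgrid(x, x)
r, t = jnp.hypot(X, Y), jnp.arctan2(Y, X)

# Functions are pytree leaves, so tree_map evaluates every entry of the
# table over the same grid, returning a dict of basis arrays.
basis = jax.tree_util.tree_map(lambda fn: fn(r, t), zernikes)
```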
I suspect we want these different functionalities to exist as different classes, to prevent any one class being clogged up with extra parameters.
@benjaminpope the ~100ms was a typo on my part. In the slack you will see @ataras2 and I timed it for the square aperture and it was ~10ms (7ms).
Oh good. That is pretty reasonable. Anything is fine for a dynamic feature of dLux so long as it takes ~ 10x less time than a typical Fourier transform.
I think the structure that I might implement will be something like:

```python
class HardCodedLowOrderZernikeBasis(dl.OpticalLayer):  # very fast
    def __init__(self, noll_indices: list, coeffs: list): ...
    def __call__(self, params: dict) -> dict: ...

class ArbitraryZernikeBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list): ...
    def __call__(self, params: dict) -> dict: ...

class HardCodedLowOrderHexikeBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list): ...
    def __call__(self, params: dict) -> dict: ...

class ArbitraryOrthonormalBasisAndAperture(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: Aperture): ...
    def __call__(self, params: dict) -> dict: ...

class ZernikeBasisAndCircularAperture(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: Aperture): ...
    def __call__(self, params: dict) -> dict: ...
```
This is still just for single apertures. I haven't decided how I will cover compound apertures yet.
Edit: when I include the `AndAperture` flag it means that both the `Aperture` and the `Basis` are applied in the same layer. This should be slightly more efficient, because the aperture array then only needs to be generated once and used twice.
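As a hypothetical illustration of that saving (none of these names are the dLux API):

```python
# Sketch of a combined layer body: the aperture transmission array is
# generated once and used twice, so fusing the two layers saves one
# aperture evaluation per pass.
def apply_basis_and_aperture(amplitude, opd, coords, coeffs,
                             make_aperture, basis_fns):
    aperture = make_aperture(coords)  # generated once
    aberration = sum(c * fn(coords) for c, fn in zip(coeffs, basis_fns))
    new_amplitude = amplitude * aperture    # use 1: mask the wavefront
    new_opd = opd + aperture * aberration   # use 2: mask the basis OPD
    return new_amplitude, new_opd
```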
Does this structure look any good?
For the case of JWST and Heimdallar (spelled how it sounds) we need to concern ourselves with a `CompoundAperture` with `Basis` terms generated for each `Aperture`. While this can be done by specifying them all individually, my old code was considerably nicer in this regard. It seems almost worth having:
```python
class CompoundCircularApertureAndBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: CompAperture): ...
    def __call__(self, params: dict) -> dict: ...

class CompoundHexagonalApertureAndBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: CompAperture): ...
    def __call__(self, params: dict) -> dict: ...

class CompoundArbitraryApertureAndBasis(dl.OpticalLayer):
    def __init__(self, noll_indices: list, coeffs: list, aperture: CompAperture): ...
    def __call__(self, params: dict) -> dict: ...
```
But I am not so sure this is a good set-up. It should be fine in the short term, I think. The main concern is that there is no way to get the `Basis` independently of the aperture. This may not even be a bad thing.
Having done some work, I think it is best if the Zernike basis is always associated with an aperture. It simplifies things considerably.
So we are going to want to do this with inheritance. We should be able to have base classes for the apertures and for the bases.
The final classes should exist primarily as lightweight wrappers that inherit from the combination of these.
They should also be able to be instantiated by themselves though!
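Roughly, something like this (all of the names below are assumptions, not the dLux API):

```python
# Hypothetical base classes for the two pieces of functionality.
class ApertureBase:
    def make_aperture(self, coords):
        raise NotImplementedError

class BasisBase:
    def make_basis(self, coords):
        raise NotImplementedError

# The final classes are lightweight wrappers inheriting from both;
# each base can also back a standalone layer of its own.
class CircularApertureAndZernikeBasis(ApertureBase, BasisBase):
    def __call__(self, params: dict) -> dict:
        coords = params["coords"]  # assumed coordinate entry
        aperture = self.make_aperture(coords)
        basis = self.make_basis(coords)
        ...  # apply both to the wavefront in one pass
        return params
```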
So it is difficult to pre-calculate the apertures. Without access to the coordinate system stored in the wavefront, there is really no way of generating an arbitrary aperture at initialisation. @ataras2 had some interesting thoughts regarding lifting the coordinate information into `dl.Telescope`/`dl.Optics`. His idea was that `dl.Optics` store an instance of `dl.Wavefront`, and that `dl.CreateWavefront` be demoted from a layer to a function called in the `__init__` method of `dl.Optics`. The coordinates would then be resolved up front, the list of layers could be sorted, and the apertures statically generated. I think it sounds interesting, but it is likely a lot of work.
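As I understand it, the idea would look roughly like this (every name here is an assumption):

```python
# Sketch only: coordinates are resolved once at construction, so
# apertures can be generated statically rather than per pass.
# create_wavefront, pixel_coordinates and precalculate are assumed names.
class Optics:
    def __init__(self, npix, pixel_scale, layers):
        # CreateWavefront demoted from a layer to a plain function
        self.wavefront = create_wavefront(npix, pixel_scale)
        coords = self.wavefront.pixel_coordinates
        # layers that can pre-calculate (e.g. apertures) do so here
        self.layers = [layer.precalculate(coords) for layer in layers]
```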
I know that there are cases where apertures can be pre-generated. However, in these scenarios one has to make very stringent assumptions that are not always true. For example, the `dl.CircularAperture` that is currently on `main` assumes that it sits directly in the centre of the wavefront and spans the entire wavefront. While this works for the `dl.AngularMFT` and `dl.PhysicalMFT` type propagators and is a common occurrence (e.g. simple TOLIMAN and HST models), it very quickly fails (e.g. complex HST and Heimdallar).
@LouisDesdoigts what are your thoughts on @ataras2's idea?
So just have pixel scale and npix parameters that are fed to the layer at init time to do the pre-calculation, but don't store them. These can just be the same values that are fed into `CreateWavefront`.
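Something like the following, perhaps (a sketch only; `zernike_fn`, mapping a Noll index to a callable, and the `diameter` parameter are assumptions):

```python
import jax.numpy as jnp

# Hypothetical layer: npix and pixel_scale are consumed at init to
# pre-calculate the basis, then discarded rather than stored as fields.
class PrecalculatedZernikeBasis:
    def __init__(self, noll_indices, coeffs, npix, pixel_scale, diameter):
        pixels = (jnp.arange(npix) - (npix - 1) / 2) * pixel_scale
        X, Y = jnp.meshgrid(pixels, pixels)
        r = jnp.hypot(X, Y) / (diameter / 2)  # normalised radius
        t = jnp.arctan2(Y, X)
        self.basis = jnp.stack([zernike_fn(n)(r, t) for n in noll_indices])
        self.coeffs = jnp.asarray(coeffs)     # npix/pixel_scale not kept

    def opd(self):
        # weighted sum over the pre-calculated basis cube
        return jnp.tensordot(self.coeffs, self.basis, axes=1)
```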
As for moving the `CreateWavefront` layer into `Optics`, it's definitely possible, but I don't think it is what we want in the long run, as it splits the parameters that define the optics across two places (some in layers and some in the optics). We already have access to the pixel scale and npix params when initialising the layers, so no assumptions need to be made!
Yeah OK, I guess that is easiest and requires the least work to implement. @ataras2 and I just posted a large UML in the slack detailing our plan for this stuff. We'll need to revise it now, but it should still be generally accurate.
Hi all, there has been a considerable amount of discussion recently about the `Basis`. This issue is going to be very long, so don't try to read it all at once.

Unless certain assumptions are made about the aperture (often but not always true), the basis must be re-evaluated at runtime. The assumptions I am referring to are those of the `CircularAperture`, which is the only aperture currently on `main`. It assumes that the edges of the aperture touch the edges of the wavefront perfectly. It works with the `zernike_basis` on `main`, which is a direct port of the `POPPY` implementation (and cannot be `jit` compiled). This reduces the amount of information that we need to know about the system in order to generate the basis to just the number of pixels.

However, I think (this is an opinion that is not particularly well informed) that we need to have a more general routine for generating the Zernike basis that can work on components that do not fit so neatly over the wavefront. I have spent a large amount of time working on this and have produced working code with a fair amount of help from @ataras2. The problem is that it is recomputed every pass, which, depending on the resolution and the number of terms, can take ~100ms. Unless we are learning large structural changes in the aperture, like I have done with the HST, I do not think that this re-computation is necessary. @benjaminpope agrees and is actually much more gung-ho, going so far as to say that we don't need to recompute in almost every (if not every) use case.
Regardless, the code @ataras2 and I have produced can be `jit` compiled, making it GPU friendly, which should hopefully reduce the run-time since it is hand-vectorised. I have also attempted to cache the basis vectors so that they are calculated on the first pass and then never again. Unfortunately `jax`/`equinox` does not permit the use of the `@cached_property` decorator from `functools`, so the simple option cannot be used. My next attempt tried to get around the immutability enforced by `equinox` by accessing the fields via the `.__dict__` attribute. While this works, it cannot be performed from within a traced function, making it useless in this scenario.

At this point we had a meeting and discussed just having a hard-coded array of a finite number of the Zernike terms. This led to the creation of a table of hard-coded Zernike `lambda` functions. I am in the process of generalising this to write a function that takes a Noll index and returns a `jax.tree_util.Partial` representing the corresponding Zernike (a sketch follows at the end of this post). These "functional Zernikes" have the advantage of faster computation because the coefficients do not need to be calculated ahead of time. According to some early profiling that @ataras2 and I did, this resulted in a tenfold speed-up running on a GPU.

The orthonormalisation code does no harm and I would like to keep it in the off chance that it is useful. I think it has the potential to improve accuracy, and if only the orthonormalisation is happening, without the Zernikes being generated from scratch each time, it is not actually too slow, taking ~1ms to complete even for reasonably high resolutions. This was again timed on @ataras2's GPU, and the performance is considerably worse on CPU.

So, now that you know the history and state of things at the moment, let us discuss what remains to be clarified.
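For reference, a minimal sketch of the "functional Zernike" idea, a function from a Noll index to a `jax.tree_util.Partial` (the index conversion and normalisation follow the standard Noll convention; none of this is the actual dLux code):

```python
from math import factorial

import jax.numpy as jnp
from jax.tree_util import Partial

def noll_to_nm(j: int):
    # Standard Noll-index -> (n, m) conversion.
    n, j1 = 0, j - 1
    while j1 > n:
        n += 1
        j1 -= n
    m = (-1) ** j * ((n % 2) + 2 * ((j1 + (n + 1) % 2) // 2))
    return n, m

def zernike_fn(j: int) -> Partial:
    # The radial coefficients are computed once here, in plain python,
    # so the returned Partial does no coefficient work when traced.
    n, m = noll_to_nm(j)
    m_abs = abs(m)
    ks = range((n - m_abs) // 2 + 1)
    coeffs = jnp.array([(-1) ** k * factorial(n - k)
                        / (factorial(k) * factorial((n + m_abs) // 2 - k)
                           * factorial((n - m_abs) // 2 - k)) for k in ks])
    powers = jnp.array([n - 2 * k for k in ks])
    norm = jnp.sqrt((1.0 if m == 0 else 2.0) * (n + 1))

    def _eval(coeffs, powers, norm, r, theta):
        radial = jnp.sum(coeffs * r[..., None] ** powers, axis=-1)
        azimuth = jnp.cos(m_abs * theta) if m >= 0 else jnp.sin(m_abs * theta)
        return norm * radial * azimuth

    return Partial(_eval, coeffs, powers, norm)

# usage: Z4 = zernike_fn(4); opd = Z4(r, theta)  # jit/vmap friendly
```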