GalSim-developers / JAX-GalSim

JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.

Moffat #49

Closed jecampagne closed 1 year ago

jecampagne commented 1 year ago

This is an import of the code I developed in my original fork of the repo. I have commented out the _drawKImage method that uses my draw_by_kValue implementation. I have also introduced bessel.py, integrate.py and interpolate.py:

from jax_galsim.bessel import J0
from jax_galsim.integrate import ClenshawCurtisQuad, quad_integral
from jax_galsim.interpolate import InterpolatedUnivariateSpline

I have tried to mimic the GalSim code for both the truncated and non-truncated Moffat, but see GalSim issues #1208 and #1209.

jecampagne commented 1 year ago

Hi, looking at the code I realized that I do not use InterpolatedUnivariateSpline (and so interpolate.py either), as in the _kvalue_trunc function I directly use the self._hankel(k) function computed with the integration method. So this may allow some cleaning once the code is validated.

ismael-mendoza commented 1 year ago

@jecampagne @EiffL I just finished cleaning the files, so there are no diffs with main except for JE's new modifications. @jecampagne you should fetch and pull when you get a chance.

jecampagne commented 1 year ago

Hi, I have done some re-coding to match the proposed requirements.

jecampagne commented 1 year ago

Ha, I have found that I should revisit the j0 computation due to the presence of a while_loop. I investigated the use of a fori_loop with "fixed" bounds. It sounded good... The fori_loop was OK, but it gives a NaN when computing the jacrev at z=0. In fact I finally use a simple polynomial ratio

@jax.jit
def R(z, num,denom):
  return jnp.polyval(num,z)/jnp.polyval(denom,z)

which is a bit slower than the boost poly-rational algorithm, but has the merit of being jacrev compliant and of giving the right numerical answer at z=0. This will be implemented in a next commit.
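A minimal sketch of that check, with made-up coefficients: `num` and `denom` below are hypothetical (they are not the boost coefficients and do not approximate J0), chosen only so that R(0) = 1 and the slope at 0 vanishes, as for J0.

```python
import jax
import jax.numpy as jnp


@jax.jit
def R(z, num, denom):
    # Ratio of two polynomials, coefficients given highest degree first.
    return jnp.polyval(num, z) / jnp.polyval(denom, z)


# Hypothetical coefficients, for illustration only.
num = jnp.array([-0.2, 0.0, 1.0])    # 1 - 0.2 z^2
denom = jnp.array([0.05, 0.0, 1.0])  # 1 + 0.05 z^2

value_at_0 = R(0.0, num, denom)
# Reverse-mode derivative with respect to z, evaluated at z = 0.
slope_at_0 = jax.jacrev(R)(0.0, num, denom)

print(value_at_0, slope_at_0)
```

The point is only that the derivative at z=0 is finite; the accuracy of a real J0 approximation depends on the actual coefficient tables.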

jecampagne commented 1 year ago

Well, I have a jit problem. If I do:

>>> import jax_galsim as galsim
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
>>> psf = galsim.Moffat(beta=5.0, flux=0.2, half_light_radius=1.0)
>>> psf
galsim.Moffat(beta=5.0, scale_radius=Array(2.2989595, dtype=float32, weak_type=True), trunc=0.0, flux=0.2, gsparams=galsim.GSParams(128,8192,0.005,5,0.001,1e-05,1e-05,1,0.0001,1e-06,1e-06,1e-08,1e-05))

That's Ok

Now:

>>> import jax
>>> identity = jax.jit(lambda x: x)
>>> identity(galsim.Moffat(beta=5.0, flux=0.2, half_light_radius=1.0))

it finally leads to

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/gsobject.py", line 981, in tree_unflatten
    return cls(**(children[0]), **aux_data)
  File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/moffat.py", line 103, in __init__
    if trunc == 0.0 and beta <= _beta_thr:
  File "/pbs/home/c/campagne/my_sps/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/errors.py", line 179, in __init__
    "Abstract tracer value encountered where concrete value is expected: "
  File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/gsobject.py", line 981, in tree_unflatten
    return cls(**(children[0]), **aux_data)
  File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/moffat.py", line 115, in __init__
    raise _galsim.GalSimIncompatibleValuesError(
galsim.errors.GalSimIncompatibleValuesError: Only one of scale_radius, half_light_radius, or fwhm may be specified Values {'half_light_radius': False, 'scale_radius': False, 'fwhm': False}

and seems to be related to

if trunc == 0.0 and beta <= _beta_thr:

So the __init__ method must be revisited to integrate the argument diagnostics and variable initialization.
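The failure mode can be reproduced outside of JAX-GalSim. The sketch below is a toy under stated assumptions: `BETA_THR` is a hypothetical stand-in for `_beta_thr`, and the two functions are illustrative, not the Moffat code. A Python `if` on a traced value fails under jit, while `jnp.where` keeps the decision inside the traced computation.

```python
import jax
import jax.numpy as jnp

BETA_THR = 1.1  # hypothetical threshold, for illustration only


def check_with_python_if(beta, trunc):
    # `and` / `if` need a concrete Python bool, which a tracer cannot provide.
    if trunc == 0.0 and beta <= BETA_THR:
        return jnp.nan
    return beta - 1.0


def check_with_where(beta, trunc):
    # Same decision expressed as array operations, so it can be traced.
    bad = jnp.logical_and(trunc == 0.0, beta <= BETA_THR)
    return jnp.where(bad, jnp.nan, beta - 1.0)


try:
    jax.jit(check_with_python_if)(5.0, 0.0)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

value = jax.jit(check_with_where)(5.0, 0.0)
print(raised, value)
```

Note that `jnp.where` cannot raise an exception the way the original argument check does, so validation of this kind has to happen outside the traced code or be dropped.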

jecampagne commented 1 year ago

Ok. So concerning the jitting of Moffat, one way could be to overload

@classmethod
def tree_unflatten(cls, aux_data, children):
    """Recreates an instance of the class from its flattened representation"""
    obj = object.__new__(cls)
    obj.xxxx = ...
    return obj

The concern is then the obj.xxxx = ... statements.

When you look at the __init__ method, one needs beta, trunc, flux, gsparams and one and only one of the following arguments not None: scale_radius, half_light_radius, fwhm. Then one computes self._r0 (aka scale_radius), self._hlr (aka half_light_radius) and self._fwhm (fwhm) according to which of the three above-mentioned arguments was chosen. Then, once these parameters are set, the registration is triggered via

super().__init__(
    beta=beta,
    scale_radius=self._r0,
    half_light_radius=self._hlr,
    fwhm=self._fwhm,
    trunc=trunc,
    flux=flux,
    gsparams=gsparams,
)

The point is that after this registration there are some other variables computed: self._maxRrD_sq, self._norm, self._knorm, self._knorm_bis, self._r0_sq, self._inv_r0_sq, self._maxk0, self._stepk0.

So I was wondering how to manage the tree_unflatten mechanism with the different params/gsparams features. E.g. should I trigger the registration super().__init__ with all the computed variables of the form self.xyz?

Well, I probably need to completely reconsider the initialisation and registration, and also use property-like methods for some derived variables.

jecampagne commented 1 year ago

I've completely rewritten the __init__ function following the Gaussian case and without overloading tree_unflatten. Then, I have written a series of properties. Some questions arise concerning the use of jax.lax.select( ,...) for example. Moreover, some cross-checks of the Moffat.__init__ arguments have been dropped for the time being.

The current version passes the jitting/vmapping tests
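A minimal sketch of that pattern, not the actual jax_galsim code: `ToyMoffat` is a hypothetical class that only stores its arguments in `__init__` and exposes derived quantities as properties, so the default `cls(*children)` unflattening works under jit. The `_norm` expression is the untruncated normalisation $F(\beta-1)/(\pi r_d^2)$.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyMoffat:
    """Hypothetical stand-in: __init__ stores values, properties derive the rest."""

    def __init__(self, beta, scale_radius, flux=1.0):
        # No value-dependent Python branching here, so tracers are accepted.
        self.beta = beta
        self.scale_radius = scale_radius
        self.flux = flux

    @property
    def _r0_sq(self):
        return self.scale_radius**2

    @property
    def _norm(self):
        # Untruncated Moffat normalisation: F (beta - 1) / (pi r_d^2)
        return self.flux * (self.beta - 1.0) / (jnp.pi * self._r0_sq)

    def tree_flatten(self):
        children = (self.beta, self.scale_radius, self.flux)
        return children, None  # no static aux_data in this toy

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


identity = jax.jit(lambda x: x)
roundtrip = identity(ToyMoffat(5.0, 2.3))
norm = jax.jit(lambda p: p._norm)(ToyMoffat(5.0, 2.3))
print(type(roundtrip).__name__, norm)
```

The trade-off is that each property is recomputed on every access outside jit; inside a jitted function the compiler sees the whole expression anyway.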

python -m pytest tests/jax/test_vmapping.py::test_moffat_vmapping
.                                                                                                         [100%]
1 passed in 1.26s
 python -m pytest tests/jax/test_jitting.py::test_moffat_jitting
.                                                                                                         [100%]
1 passed in 1.04s

but 1) I am not sure that these tests are strong enough, and 2) what about jacrev is a question that remains open.

jecampagne commented 1 year ago

Well, there are errors in the current Moffat implementation related to maxK when pytest enters the drawing phase with convolution. For now, I do not understand why pytest seems to indicate that the "truncated Moffat" is triggered while trunc==0. And then there are JAX errors...

jecampagne commented 1 year ago

OK. I made some progress on making Moffat work with the pytest tests that are triggered so far.

[image]

The test_jitting & test_vmapping tests pass as well.

But this is not the end of the story, as I see some discrepancy when comparing the kValue(PositionD(0.1,-0.2)) computation with the current GalSim Moffat implementation (i.e. the one available when cloning JAX-GalSim) for the "truncated Moffat". I will investigate...

The "untruncated Moffat" is OK as far as I checked, and can be used for further testing (e.g. some shear parameter inference).

ismael-mendoza commented 1 year ago

@jecampagne you should go into the tests/galsim_tests_config.yaml and remove the line in allowed_failures that says

"module 'jax_galsim' has no attribute 'Moffat'"

as we have now implemented this. Then you should add the line:

AttributeError: module 'jax_galsim' has no attribute 'Deconvolve'

as we have not yet implemented this function.

Finally, in the enabled_tests at the top, you should add a line:

- test_moffat.py

which might also enable additional tests. Let us know what you find

jecampagne commented 1 year ago

Hello, well the truncated Moffat part is not working, and it also makes the untruncated Moffat fail when trying to draw, as one can see in the nb on Colab (certainly because a lax.cond is automatically transformed into a lax.select, which then compiles both types of Moffat functions in _kValue(self, kpos)). In the past I was playing with the truncated or untruncated Moffat quite nicely, see here. So it is a matter of JAX technology....

In fact, in the past, at the initialisation phase, maxk, stepk and all the functions specialising the truncated vs untruncated form were stored in self. variables, e.g. to point to the right function (i.e. self._kV=self._kValue_untrunc), and then used by

def _kValue(self, kpos):
    return self._kV(kpos)

Now, the switch is performed as

def _kValue(self, kpos):
    return jax.lax.cond(
        self.trunc > 0,
        lambda x: self._kvalue_trunc(x),
        lambda x: self._kValue_untrunc(x),
        kpos,
    )

I do not know if it is the right thing to do, or whether a mechanism similar to the old one could be implemented using a cached workspace.
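To make the two options concrete, here is a toy comparison. The two `kvalue_*` functions are hypothetical stand-ins, not the Moffat formulas. With `lax.cond` the predicate is a traced value and both branches are traced and compiled; if `trunc` is instead treated as a static Python number, the choice is made once at trace time and only one branch is compiled, which is close in spirit to the old self._kV pointer.

```python
from functools import partial

import jax
import jax.numpy as jnp


def kvalue_untrunc(k):
    # Hypothetical stand-in for the untruncated branch.
    return jnp.exp(-k)


def kvalue_trunc(k):
    # Hypothetical stand-in for the truncated branch.
    return jnp.exp(-(k**2))


@jax.jit
def kvalue_cond(trunc, k):
    # trunc is traced: both branches end up in the compiled program.
    return jax.lax.cond(trunc > 0, kvalue_trunc, kvalue_untrunc, k)


@partial(jax.jit, static_argnums=0)
def kvalue_static(trunc, k):
    # trunc is a static Python float: only the selected branch is traced.
    if trunc > 0:
        return kvalue_trunc(k)
    return kvalue_untrunc(k)


k = jnp.array([0.5, 1.0])
print(kvalue_cond(0.0, k), kvalue_static(0.0, k))
print(kvalue_cond(2.0, k), kvalue_static(2.0, k))
```

The static variant recompiles for each distinct `trunc` value and cannot be vmapped or differentiated over `trunc`, so it only fits if `trunc` is considered a fixed configuration parameter.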

jecampagne commented 1 year ago

With the current version, Moffat has a problem with pytest, as calling psf.kValue with a single galsim.PositionD is not allowed.

[image]

1) This single PositionD(kx,ky) call is not made, for instance, when drawing an image. 2) If one wants to compare with GalSim, I use

import galsim as _galsim
import jax
import jax.numpy as jnp
import jax_galsim as galsim

psf_beta = 5.
scale_radius = 2.3

# Reference values from GalSim
_psf = _galsim.Moffat(beta=psf_beta, flux=1., scale_radius=scale_radius)
print("(3) Galsim psf:",_psf)

print(_psf.beta, _psf.scale_radius, _psf.trunc, _psf.half_light_radius, _psf.fwhm, _psf.maxk, _psf.stepk)
print("xValue: ",_psf._xValue(_galsim.PositionD(0.1,-0.2)))
# same point as the JAX-GalSim call below
_kval = _psf._kValue(_galsim.PositionD(0.1,-0.2))
assert _kval.imag == 0., "curious.... _kValue with imaginary part"
print("kValue: ",_kval.real)
print("---------------")
# Same profile with JAX-GalSim; kValue is called through vmap on array inputs
psf = galsim.Moffat(beta=psf_beta, flux=1., scale_radius=scale_radius)
print("JaxGalsim psf:",psf)
print(psf.beta, psf.scale_radius, psf.trunc, psf.half_light_radius, psf.fwhm, psf.maxk, psf.stepk)
print("xValue: ",psf.xValue(galsim.PositionD(0.1,-0.2)))
coords = jnp.array([[[ 0.1, -0.2]]])
kval = jax.vmap(lambda *args : psf.kValue(galsim.PositionD(*args)))(coords[...,0],coords[...,1])
print("kValue: ",kval[0,0])

Now, it is certainly a matter of tweaking the JAX-GalSim Moffat code to accept the GalSim test...

One possibility, though not so aesthetic (a more elegant way should be found), could be to tune the _kValue_trunc & _kValue_untrunc argument handling to keep track of the input dimension, then transform the input into an atleast_1d jax numpy array, and return the result according to the original input structure

import jax
import jax.numpy as jnp

def f(x):
  nd = jnp.ndim(x)       # static: known at trace time
  x = jnp.atleast_1d(x)
  res = x**3
  if nd==0:
    return res[0]        # scalar in, scalar out
  else:
    return res
xin = 12.; print(xin, " res=",f(xin))
xin = jnp.array([12.]); print(xin, " res=",f(xin))

leading to

12.0  res= 1728.0
[12.]  res= [1728.]

and it would work with jacrev

print(jax.jacrev(f)(-12.), 3*12.**2)

But this is just to be able to use the GalSim test... Is that OK, or do we design a JAX-specific test? Notice that xValue does not suffer from the same problem.

jecampagne commented 1 year ago

I will explain the maxk computation philosophy in GalSim using the 2D Gaussian profile, and then discuss the (untruncated) Moffat. So, let us start with an isotropic Gaussian profile: $$f(r)=C \exp\left(-\frac{1}{2}\left(\frac{r}{\sigma}\right)^2\right)$$ If we define the 2D integral of $f(r)$ as the flux (noted $F_\infty$ to illustrate that $r$ is untruncated), then $$f(r) = \frac{F_\infty}{2\pi\sigma^2}\times \exp\left(-\frac{1}{2}\left(\frac{r}{\sigma}\right)^2\right)$$ This is exactly what GalSim computes as xValue (see SBGaussian.cpp). Now, concerning the 2D Fourier representation, one has to ask which definition of the Fourier transform GalSim uses (always a nightmare with constant definitions!). Well, it turns out that the definition is the so-called (a,b)=(1,-1) in the Mathematica wording, and it leads to $$f(k) = \frac{F_\infty}{2\pi\sigma^2}\times \int_0^\infty J_0(kr) \exp\left(-\frac{1}{2}\left(\frac{r}{\sigma}\right)^2\right)\ 2\pi r\, dr = F_\infty \exp\left(-\frac{1}{2}(k \sigma)^2\right)$$ and this is also what is coded as kValue in the same cpp file as above.

Concerning the "maxk" computation, it turns out that in practice one considers the ratio $f(k)/F_\infty$ and solves for the $k$ at which this ratio matches a threshold (i.e. gsparams.maxk_threshold). Let us call the maxk_threshold $gs_k$; then $$f(k_{max})/F_\infty = gs_k$$ which leads to the GalSim expression for the Gaussian profile.
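Spelled out for the Gaussian, using only the $f(k)$ given above (so this is plain algebra, not a quote of the GalSim source): $$\exp\left(-\frac{1}{2}(k_{max}\,\sigma)^2\right) = gs_k \quad\Rightarrow\quad k_{max} = \frac{\sqrt{-2\ln gs_k}}{\sigma}$$ For instance, $gs_k = 10^{-3}$ gives $k_{max}\,\sigma \approx 3.72$.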

So far we have all the definitions in hand to address the case of the untruncated Moffat profile. If one defines $$f(r) = C \left( 1 + \left(\frac{r}{r_d}\right)^2\right)^{-\beta}$$ then, in the same spirit as for the Gaussian profile, the definition of the flux gives the expression of $C$ as $$C = \frac{F_\infty}{\pi r_d^2}(\beta -1)$$ which is exactly what GalSim uses to compute xValue (see SBMoffat.cpp). Now, for the k-space expression, the Fourier transform definition given above leads to $$f(k) = 4 F_\infty \frac{\tilde{k}^{\beta-1}}{2^\beta} \frac{K[\beta-1,\tilde{k}]}{\Gamma[\beta-1]}$$ with $\tilde{k} = k\times r_d$ (a dimensionless variable) and $K$ the modified Bessel function of the second kind. This leads to the following asymptotic expression when $k\rightarrow \infty$: $$f(k) \approx F_\infty \frac{\sqrt{\pi}}{\Gamma[\beta-1]} \left(\frac{\tilde{k}}{2} \right)^{\beta-3/2} \exp(-\tilde{k})$$ To find maxk, one then needs to solve the equation $$gs_k = \frac{\sqrt{\pi}}{\Gamma[\beta-1]} \left(\frac{\tilde{k}}{2} \right)^{\beta-3/2} \exp(-\tilde{k})$$ leading to the transcendental equation $$x=-\log\alpha + (\beta -3/2)\log(x/2)$$ where $x$ stands for $\tilde{k}$ and $$\alpha = \frac{gs_k \Gamma[\beta-1]}{\sqrt{\pi}}$$ We can solve this by iteration starting with $x=-\log\alpha$.
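The iteration can be sketched in a few lines of plain Python. The function name and the fixed iteration count are hypothetical choices for illustration, there is no convergence check, and the example values (beta=5, r_d=2.3, threshold 1e-3) are arbitrary. The final lines plug the solution back into the asymptotic equation above.

```python
import math


def moffat_maxk_asymptotic(beta, r_d, gs_k, n_iter=100):
    """Fixed-point solve of x = -log(alpha) + (beta - 3/2) log(x / 2); returns x / r_d."""
    # alpha = gs_k * Gamma(beta - 1) / sqrt(pi), evaluated in log space
    log_alpha = math.log(gs_k) + math.lgamma(beta - 1.0) - 0.5 * math.log(math.pi)
    x = -log_alpha  # starting point suggested in the text
    for _ in range(n_iter):
        x = -log_alpha + (beta - 1.5) * math.log(x / 2.0)
    return x / r_d  # x stands for k * r_d


beta, r_d, gs_k = 5.0, 2.3, 1e-3
maxk = moffat_maxk_asymptotic(beta, r_d, gs_k)

# Check: the right-hand side of the asymptotic equation should reproduce gs_k.
x = maxk * r_d
rhs = math.sqrt(math.pi) / math.gamma(beta - 1.0) * (x / 2.0) ** (beta - 1.5) * math.exp(-x)
print(maxk, rhs)
```

The iteration only contracts when $(\beta - 3/2)/x < 1$ near the solution, so very large $\beta$ combined with a loose threshold may need a proper root finder instead.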

So far so good, BUT GalSim still uses a wrong asymptotic expression, which gives $$gs_k = \frac{2\sqrt{\pi}}{\Gamma[\beta-1]} \left(\frac{\tilde{k}}{2} \right)^{\beta-1/2} \exp(-\tilde{k})$$ which then amounts to solving $$x=-\log a + (\beta -1/2)\log x$$ with $$a = \frac{gs_k \Gamma[\beta-1]}{2\sqrt{\pi}} 2^{\beta-1/2}$$

So, for the time being I have coded the GalSim way of computing maxk to get the GalSim vs JAX-GalSim comparison as close as possible, but I have also kept, commented out, what I would consider the right computation, and I raised an issue in the GalSim repository, see #1208. @rmjarvis has some arguments for keeping the current way of computing maxk.

rmjarvis commented 1 year ago

FYI, #1208 has already been fixed in the main branch of GalSim. cf. https://github.com/GalSim-developers/GalSim/pull/1210

rmjarvis commented 1 year ago

But note that we don't use the corrected asymptotic expansion either. It turns out that neither formula actually gives very good results, so we just use the full calculation to compute maxk without approximation.

jecampagne commented 1 year ago

What I would also like to propose is the following: use the self._workspace = {} mechanism used in jax_cosmo/core.py:Cosmology to store maxk and stepk. E.g.

    # in __init__
    .....
    self._ws = {}  # not stored in params nor gsparams
    ......

    # then use self._ws to store the value of maxk at first use
    @property
    def _maxk(self):
        if "_maxk" not in self._ws:
            res = jax.lax.select(self.trunc > 0, self._maxk_trunc, self._maxk_untrunc)
            self._ws["_maxk"] = res
        return self._ws["_maxk"]

What do you think @EiffL ? I have made a demo which passes our jitting and vmapping tests.

jecampagne commented 1 year ago

Oh! Sorry, I missed @rmjarvis's comment, thanks. So is the new computation activated when we use, for the JAX-GalSim installation, install_requires "galsim >= 2.3.0"?

jecampagne commented 1 year ago

One addition: looking at the GalSim releases/2.4 test_moffat.py file, it seems that the "maxk" numerical result is only tested in the truncated case. Is that right? Would it not be valuable to extend it to the untruncated case too?

rmjarvis commented 1 year ago

install_requires "galsim >= 2.3.0"

No. It will be released with 2.5.0, but we're not quite ready for that yet. (The current version is 2.4.11.)

jecampagne commented 1 year ago

Hi @rmjarvis

Maybe this message would be more appropriate for the GalSim repo (let me know), but while looking at a possible evolution of the Moffat code for JAX-GalSim, I was reading the SBMoffat.cpp code in the GalSim main branch. I am curious about this:

// Set maxK to the value where the FT is down to maxk_threshold
    double SBMoffat::SBMoffatImpl::maxK() const
    {
        if (_maxk == 0.) {
            if (_trunc == 0.) {
                MoffatMaxKSolver func(this, this->gsparams.maxk_threshold);
                double ksq1 = 0.;
                double ksq2 = 100;  // k=10 is usually close, so it makes a good starting guess.
                Solve<MoffatMaxKSolver> solver(func,ksq1,ksq2);
....

Don't you feel that ksq2=100 (i.e. an upper bound of 10 for k) is a bit low? For instance rd = 2., beta=30 and k_threshold = 1e-3 would lead to a maxK of 15 or so => maxK^2 = 225.

rmjarvis commented 1 year ago

ie. for instance rd = 2., beta=30 and k_threshold = 1e-3 would leads to a maxK of 15 or so => maxK^2 = 225.

Did you try it? I'm pretty sure it gets that right.

jecampagne commented 1 year ago

ie. for instance rd = 2., beta=30 and k_threshold = 1e-3 would leads to a maxK of 15 or so => maxK^2 = 225.

Did you try it? I'm pretty sure it gets that right.

I finally managed to get the main-branch GalSim running, and you're right: the code finds maxK=14.708. So I guess the root-finding algorithm pushes the upper bound further when needed, and that does the job.
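As a rough illustration of that guess only: the sketch below is not GalSim's Solve class (its internals are not quoted in this thread), and it uses the Gaussian ratio $\exp(-k^2\sigma^2/2)$ as a stand-in because it needs no Bessel function. The helper name, the doubling strategy and sigma=0.25 are all hypothetical. Starting from the same [0, 100] bracket in $k^2$, the upper bound is doubled until the function changes sign, then plain bisection finishes the job.

```python
import math


def solve_with_bracket_expansion(func, lo, hi, tol=1e-12, max_expand=60):
    """Double hi until func changes sign on [lo, hi], then bisect."""
    f_lo = func(lo)
    for _ in range(max_expand):
        if f_lo * func(hi) <= 0.0:
            break
        hi *= 2.0  # push the upper bound further out
    else:
        raise RuntimeError("no sign change found while expanding the bracket")
    while hi - lo > tol * max(1.0, hi):
        mid = 0.5 * (lo + hi)
        if f_lo * func(mid) > 0.0:
            lo = mid  # root is in the upper half; sign at lo is unchanged
        else:
            hi = mid
    return 0.5 * (lo + hi)


sigma, threshold = 0.25, 1e-3  # arbitrary illustrative values


def ratio_minus_threshold(ksq):
    # Gaussian f(k)/F as a function of k^2, minus the target threshold
    return math.exp(-0.5 * ksq * sigma**2) - threshold


ksq = solve_with_bracket_expansion(ratio_minus_threshold, 0.0, 100.0)
maxk = math.sqrt(ksq)
print(ksq, maxk)
```

Here the root sits beyond the initial upper bound of 100 in $k^2$, so the expansion step is what makes the initial guess harmless.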

We can close the thread. Thanks @rmjarvis

ismael-mendoza commented 1 year ago

@jecampagne @EiffL I'm merging now to avoid deprecation, I think Moffat passes almost all tests in test_moffat.py. I will make an issue to document tests that can still be activated