jecampagne closed this 1 year ago
Hi, looking at the code I realized that I do not use InterpolatedUnivariateSpline (and hence interpolate.py either): in the _kvalue_trunc function I directly use the self._hankel(k) function, which is computed by numerical integration. So these files could be cleaned up once the code is validated.
@jecampagne @EiffL I just finished cleaning the files, so there are no diffs with main except for JE's new modifications. @jecampagne you should fetch and pull when you get a chance.
Hi, I have done some re-coding to match the proposed requirements.
Ha, I have found that I should revisit the j0 computation because it uses a while_loop, which JAX cannot differentiate in reverse mode (jacrev).
I am investigating the use of a fori_loop with "fixed" (static) bounds, which can be differentiated in reverse mode. Sounds good...
The fori_loop was OK, but it gives a NaN when computing the jacrev at z=0.
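A minimal sketch of the fixed-bounds idea, assuming a truncated power series for J0 rather than the j0 algorithm used in this PR (j0_series and n_terms=20 are hypothetical): because the trip count of jax.lax.fori_loop is a Python constant, JAX lowers the loop to a scan, which jax.jacrev can differentiate, unlike a while_loop. The series is only accurate for moderate |z|, and this sketch happens to stay finite at z=0; whether a NaN appears depends on the actual algorithm being looped over.

```python
import jax
import jax.numpy as jnp


def j0_series(z, n_terms=20):
    """Hypothetical J0 via its power series sum_m (-1)^m (z/2)^(2m) / (m!)^2.

    Only accurate for moderate |z|; used here to show a reverse-differentiable loop.
    """
    z = jnp.asarray(z, dtype=jnp.float32)
    x2 = -((z / 2.0) ** 2)

    def body(m, carry):
        term, total = carry
        # Recurrence t_m = t_{m-1} * (-(z/2)^2) / m^2
        term = term * x2 / (m * m)
        return term, total + term

    init = (jnp.ones_like(z), jnp.ones_like(z))  # t_0 = 1, partial sum = 1
    # Static bounds (Python ints) let JAX lower fori_loop to scan, so jacrev works
    _, total = jax.lax.fori_loop(1, n_terms, body, init)
    return total


print(j0_series(1.0))              # close to J0(1) ~ 0.7652
print(jax.jacrev(j0_series)(1.0))  # close to -J1(1) ~ -0.4401
print(jax.jacrev(j0_series)(0.0))  # 0 at z = 0, no NaN for this series
```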
In the end I simply use a ratio of polynomials:
import jax
import jax.numpy as jnp

@jax.jit
def R(z, num, denom):
    # Rational approximation: num and denom are polynomial coefficients, highest degree first
    return jnp.polyval(num, z) / jnp.polyval(denom, z)
which is a bit slower than the Boost rational-polynomial algorithm, but has the merit of being jacrev-compliant and of giving the right numerical answer at z=0. This will be implemented in a next commit.
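A usage sketch of R, assuming as a test case the [2/2] Padé approximant of exp(z) (not one of the actual Bessel coefficient sets): coefficients are passed highest degree first, as jnp.polyval expects, and jax.jacrev differentiates with respect to z. At z=0 both the value and the derivative of this approximant are exactly 1, so it is an easy check that the gradient is finite there.

```python
import jax
import jax.numpy as jnp


@jax.jit
def R(z, num, denom):
    # Ratio of two polynomials; coefficients ordered from highest degree to constant term
    return jnp.polyval(num, z) / jnp.polyval(denom, z)


# [2/2] Pade approximant of exp(z): (1 + z/2 + z^2/12) / (1 - z/2 + z^2/12)
num = jnp.array([1.0 / 12.0, 0.5, 1.0])
denom = jnp.array([1.0 / 12.0, -0.5, 1.0])

print(R(0.0, num, denom))                # 1.0
print(jax.jacrev(R)(0.0, num, denom))    # 1.0, finite at z = 0
print(R(0.5, num, denom), jnp.exp(0.5))  # agree to about 1e-4
```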
Well, I have a jit problem. If I do:
>>> import jax_galsim as galsim
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
>>> psf = galsim.Moffat(beta=5.0, flux=0.2, half_light_radius=1.0)
>>> psf
galsim.Moffat(beta=5.0, scale_radius=Array(2.2989595, dtype=float32, weak_type=True), trunc=0.0, flux=0.2, gsparams=galsim.GSParams(128,8192,0.005,5,0.001,1e-05,1e-05,1,0.0001,1e-06,1e-06,1e-08,1e-05))
That's OK.
Now:
>>> import jax
>>> identity = jax.jit(lambda x: x)
>>> identity(galsim.Moffat(beta=5.0, flux=0.2, half_light_radius=1.0))
it finally leads to
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/gsobject.py", line 981, in tree_unflatten
return cls(**(children[0]), **aux_data)
File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/moffat.py", line 103, in __init__
if trunc == 0.0 and beta <= _beta_thr:
File "/pbs/home/c/campagne/my_sps/anaconda3/envs/jaxgalsim/lib/python3.8/site-packages/jax/_src/errors.py", line 179, in __init__
"Abstract tracer value encountered where concrete value is expected: "
File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/gsobject.py", line 981, in tree_unflatten
return cls(**(children[0]), **aux_data)
File "/sps/lsst/users/campagne/MyGalSim/JAX-GalSim/jax_galsim/moffat.py", line 115, in __init__
raise _galsim.GalSimIncompatibleValuesError(
galsim.errors.GalSimIncompatibleValuesError: Only one of scale_radius, half_light_radius, or fwhm may be specified Values {'half_light_radius': False, 'scale_radius': False, 'fwhm': False}
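and it seems to be related to the line
if trunc == 0.0 and beta <= _beta_thr:
Under jit, tree_unflatten rebuilds the object by calling __init__ with traced (abstract) values, so this Python-level test cannot be evaluated. So the __init__ method must be revisited to handle argument validation and variable initialization.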
OK. So concerning the jitting of Moffat, one way could be to override tree_unflatten:
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Recreates an instance of the class from its flattened representation"""
    obj = object.__new__(cls)
    obj.xxxx = ...
    return obj
The concern is then the obj.xxxx = ... statements.
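A minimal, self-contained sketch of both the failure and the override, assuming toy classes (BranchingProfile, UnflattenProfile and _BETA_THR are hypothetical, not jax_galsim names): under jit, the default unflatten re-runs __init__ on tracers and the Python `if` raises a ConcretizationTypeError, whereas building the object with object.__new__ skips the checks. The trade-off is the one raised above: every attribute the rest of the class relies on must be restored by hand in tree_unflatten, or recomputed lazily (e.g. via properties).

```python
import jax
from jax.tree_util import register_pytree_node_class

_BETA_THR = 1.1  # hypothetical threshold, stands in for _beta_thr


@register_pytree_node_class
class BranchingProfile:
    """Hypothetical profile whose __init__ branches on its (possibly traced) parameters."""

    def __init__(self, beta, trunc=0.0):
        if trunc == 0.0 and beta <= _BETA_THR:  # Python `if`: fails on tracers
            raise ValueError("beta too small for an untruncated profile")
        self.params = {"beta": beta, "trunc": trunc}

    def tree_flatten(self):
        return (self.params,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Default pattern: re-run __init__, which receives tracers under jit
        return cls(**children[0])


@register_pytree_node_class
class UnflattenProfile(BranchingProfile):
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: the checks already ran when the object was first built
        obj = object.__new__(cls)
        obj.params = children[0]
        return obj


identity = jax.jit(lambda p: p)

try:
    identity(BranchingProfile(beta=5.0))
except jax.errors.ConcretizationTypeError as err:
    print("BranchingProfile fails under jit:", type(err).__name__)

out = identity(UnflattenProfile(beta=5.0))
print(out.params["beta"])  # 5.0
```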
When you look at the __init__ method, one needs beta, trunc, flux, gsparams and one and only one of the following arguments not None (scale_radius, half_light_radius, fwhm). Then one computes self._r0 (aka scale_radius), self._hlr (aka half_light_radius) and self._fwhm (fwhm) according to which of the three arguments above was chosen. Then, once these parameters are set, the registration is triggered via
super().__init__(
    beta=beta,
    scale_radius=self._r0,
    half_light_radius=self._hlr,
    fwhm=self._fwhm,
    trunc=trunc,
    flux=flux,
    gsparams=gsparams,
)
The point is that after this registration there is
maxK, with a different method according to whether trunc=0 or not,
stepk, with a different method according to whether beta <= _beta_thr or not,
and some other variables are computed: self._maxRrD_sq, self._norm, self._knorm, self._knorm_bis, self._r0_sq, self._inv_r0_sq, self._maxk0, self._stepk0.
So I was wondering how to manage the tree_unflatten mechanism with the different params/gsparams features. E.g. should I trigger the registration super().__init__ with all the computed variables of the form self.xyz?
Well, I probably need to completely reconsider the initialisation and registration, and also use property-like methods for some derived variables.
I've completely rewritten the __init__ function following the Gaussian case, without overriding tree_unflatten. Then, I have written a series of properties. Some questions arise concerning the use of jax.lax.select(
The current version passes the jitting/vmapping tests:
python -m pytest tests/jax/test_vmapping.py::test_moffat_vmapping
. [100%]
1 passed in 1.26s
python -m pytest tests/jax/test_jitting.py::test_moffat_jitting
. [100%]
1 passed in 1.04s
but I am not sure 1) whether these tests are strong enough, and 2) what about jacrev: that question remains open.
Well, there are errors in the current Moffat implementation concerning maxK when pytest enters the drawing phase with convolution. I do not understand why pytest seems to indicate that the "truncated Moffat" is triggered while trunc==0. And then there are JAX errors...
OK. I made some progress in making Moffat work with the pytest tests triggered so far, as well as with test_jitting & test_vmapping.
But this is not the end of the story, as I have some discrepancy in the kValue(PositionD(0.1,-0.2)) comparison with the current GalSim.Moffat implementation (i.e. the one available when cloning JAX-GalSim) for the "truncated Moffat". I will investigate...
The "untruncated Moffat" is OK as far as I checked, and can be used for further testing (e.g. some shear parameter inference).
@jecampagne you should go into tests/galsim_tests_config.yaml and remove the line in allowed_failures that says
"module 'jax_galsim' has no attribute 'Moffat'"
as we have now implemented this. Then you should add the line:
AttributeError: module 'jax_galsim' has no attribute 'Deconvolve'
as we have not yet implemented this function.
Finally, in enabled_tests at the top, you should add a line:
- test_moffat.py
which might also enable additional tests. Let us know what you find.
Hello, well, the truncated Moffat part is not working, and it also makes the untruncated Moffat fail when trying to draw, as one can see in the nb on Colab. This is certainly due to a lax.cond that is automatically transformed into a lax.select (e.g. under vmap with a batched predicate), which then evaluates both types of Moffat functions in _kValue(self, kpos). In the past I was playing with truncated or untruncated Moffat quite nicely, see here. So it is a matter of JAX technology...
In fact, in the past, maxk, stepk and all the functions specialising the truncated vs untruncated form were set at initialisation as self. variables, e.g. pointing to the right function (i.e. self._kV=self._kValue_untrunc), which was then used by
def _kValue(self, kpos):
    return self._kV(kpos)
Now, the switch is performed as
def _kValue(self, kpos):
    return jax.lax.cond(
        self.trunc > 0,
        lambda x: self._kvalue_trunc(x),
        lambda x: self._kValue_untrunc(x),
        kpos,
    )
I do not know if it is the right thing to do, and whether a mechanism similar to the old one can be achieved using a cached workspace?
With the current version, Moffat has a problem with the pytest tests, as calling psf.kValue with a single galsim.PositionD is not allowed.
1) this single PositionD(kx,ky) call is not made, for instance, when drawing an image; 2) if one wants to compare with GalSim, I use
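A sketch of one way the select-based switch can go wrong, and of a static alternative closer to the old mechanism, assuming placeholder k-space formulas (_trunc, _untrunc, kvalue_select and ToyProfile are hypothetical, not the Moffat expressions). With jnp.where (which is what lax.cond effectively becomes under vmap with a batched predicate), the untaken branch is still evaluated, and here its division by trunc=0 leaks a NaN into the gradient even though the forward value is fine. Keeping trunc as static aux data instead lets a plain Python `if` pick one branch at trace time.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def _untrunc(k, r0):
    return jnp.exp(-k * r0)  # placeholder, not the Moffat formula


def _trunc(k, r0, trunc):
    return jnp.exp(-k * r0) * jnp.cos(k / trunc)  # placeholder; ill-defined when trunc == 0


def kvalue_select(k, r0, trunc):
    # Both branches are evaluated; the untaken one still enters the gradient
    return jnp.where(trunc > 0, _trunc(k, r0, trunc), _untrunc(k, r0))


@register_pytree_node_class
class ToyProfile:
    """Hypothetical profile with `trunc` kept as static (aux) data and `r0` as a traced leaf."""

    def __init__(self, r0, trunc=0.0):
        self.r0 = r0
        self.trunc = float(trunc)  # must be concrete; a new value triggers a recompile

    def tree_flatten(self):
        return (self.r0,), (self.trunc,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)

    def kvalue(self, k):
        # Python-level dispatch: only the selected branch is traced and compiled
        if self.trunc > 0.0:
            return _trunc(k, self.r0, self.trunc)
        return _untrunc(k, self.r0)


k = jnp.asarray(0.5)
print(kvalue_select(k, 2.0, 0.0))                        # finite value exp(-1)
print(jax.grad(kvalue_select)(k, 2.0, 0.0))              # nan, leaked from the untaken branch
print(jax.grad(lambda k: ToyProfile(2.0).kvalue(k))(k))  # finite: -2 * exp(-1)
```

The price of the static route is that trunc can no longer be vmapped over or differentiated, and each distinct trunc value compiles its own version.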
import galsim as _galsim
import jax
import jax.numpy as jnp
import jax_galsim as galsim

psf_beta = 5.
scale_radius = 2.3
_psf = _galsim.Moffat(beta=psf_beta, flux=1., scale_radius=scale_radius)
print("(3) Galsim psf:", _psf)
print(_psf.beta, _psf.scale_radius, _psf.trunc, _psf.half_light_radius, _psf.fwhm, _psf.maxk, _psf.stepk)
print("xValue: ", _psf._xValue(_galsim.PositionD(0.1, -0.2)))
_kval = _psf._kValue(_galsim.PositionD(0.1, -0.2))
assert _kval.imag == 0., "curious.... _kValue with imaginary part"
print("kValue: ", _kval.real)
print("---------------")
psf = galsim.Moffat(beta=psf_beta, flux=1., scale_radius=scale_radius)
print("JaxGalsim psf:", psf)
print(psf.beta, psf.scale_radius, psf.trunc, psf.half_light_radius, psf.fwhm, psf.maxk, psf.stepk)
print("xValue: ", psf.xValue(galsim.PositionD(0.1, -0.2)))
# kValue is evaluated through vmap, since a single PositionD call is not supported yet
coords = jnp.array([[[0.1, -0.2]]])
kval = jax.vmap(lambda *args: psf.kValue(galsim.PositionD(*args)))(coords[..., 0], coords[..., 1])
print("kValue: ", kval[0, 0])
Now, it is certainly a matter of tweaking the JAX-GalSim Moffat code to accept the GalSim test...
One possibility, not so "aesthetic" (a more elegant way should be found), could be to make _kValue_trunc & _kValue_untrunc keep track of the dimension of their argument, transform it into an at-least-1D JAX numpy array, and return according to the original input structure:
import jax
import jax.numpy as jnp

def f(x):
    nd = jnp.ndim(x)       # remember whether the input was a scalar
    x = jnp.atleast_1d(x)
    res = x**3
    if nd == 0:
        return res[0]      # give back a scalar for a scalar input
    else:
        return res

xin = 12.; print(xin, " res=", f(xin))
xin = jnp.array([12.]); print(xin, " res=", f(xin))
leading to
12.0 res= 1728.0
[12.] res= [1728.]
and it would work with jacrev:
print(jax.jacrev(f)(-12.), 3*12.**2)
But this is just to make the GalSim test pass... OK? Or should we design a JAX-specific test? Notice that xValue does not suffer from the same problem.
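A possibly more compact alternative, not discussed in the thread: jnp.vectorize wraps a function written for scalars so that it accepts scalars or arrays of any shape and returns the matching shape, without tracking ndim by hand. The name cube is hypothetical; whether this fits the PositionD-based kValue signature is an assumption to check.

```python
import jax
import jax.numpy as jnp


@jnp.vectorize
def cube(x):
    # Written for a scalar; jnp.vectorize maps it over any input shape
    return x**3


print(cube(12.0))                            # scalar in, scalar out: 1728.0
print(cube(jnp.array([12.0])))               # shape (1,) in, shape (1,) out: [1728.]
print(jax.jacrev(cube)(-12.0), 3 * 12.0**2)  # 432.0 432.0
```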
I will explain the maxk computation philosophy in GalSim using the 2D Gaussian profile, and then discuss the (untruncated) Moffat. So, let us start with an isotropic Gaussian profile: $$f(r)=C \exp\left(-\frac{1}{2}\left(\frac{r}{\sigma}\right)^2\right)$$ If we define the 2D integral of $f(r)$ as the flux (noted $F_\infty$ to emphasise that $r$ is untruncated), then $$f(r) = \frac{F_\infty}{2\pi\sigma^2} \exp\left(-\frac{1}{2}\left(\frac{r}{\sigma}\right)^2\right)$$ This is exactly what GalSim computes as xValue (see SBGaussian.cpp). Now, concerning the 2D Fourier representation, one has to question the definition of the Fourier transform used by GalSim (always a nightmare with constant conventions!). Well, it turns out that the definition is the so-called (a,b)=(1,-1) in the Mathematica wording, which leads to $$f(k) = \frac{F_\infty}{2\pi\sigma^2} \int_0^\infty J_0(kr) \exp\left(-\frac{1}{2}\left(\frac{r}{\sigma}\right)^2\right) 2\pi r\, dr = F_\infty \exp\left(-\frac{1}{2}(k \sigma)^2\right)$$ and this is also what kValue computes in the same cpp file as above.
Concerning the "maxk" computation, it turns out that in practice one considers the ratio $f(k)/F_\infty$ and solves for the $k$ at which this ratio matches a threshold (i.e. gsparams.maxk_threshold). Let us call maxk_threshold $gs_k$; then $$f(k_{max})/F_\infty = gs_k$$ which leads to the GalSim expression for the Gaussian profile, $$k_{max} = \frac{\sqrt{-2\ln gs_k}}{\sigma}$$
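A small numerical check of the Gaussian case, assuming the default maxk_threshold of 1e-3 (the 0.001 entry in the GSParams printed earlier); gaussian_maxk is a hypothetical helper name. It inverts $\exp(-\frac{1}{2}(k\sigma)^2) = gs_k$ and verifies that the ratio at $k_{max}$ gives back the threshold.

```python
import math


def gaussian_maxk(sigma, maxk_threshold=1e-3):
    """k at which the flux-normalised FT exp(-(k*sigma)^2 / 2) drops to maxk_threshold."""
    return math.sqrt(-2.0 * math.log(maxk_threshold)) / sigma


sigma = 1.0
kmax = gaussian_maxk(sigma)
print(kmax)                                   # about 3.717
print(math.exp(-0.5 * (kmax * sigma) ** 2))   # recovers the threshold 1e-3
```

Note that maxk scales as $1/\sigma$: doubling the width halves maxk.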
So far we have all the definitions in hand to address the case of the untruncated Moffat profile. If one defines $$f(r) = C \left( 1 + \left(\frac{r}{r_d}\right)^2\right)^{-\beta}$$ then, in the same spirit as for the Gaussian profile, the definition of the flux gives $C$ as $$C = \frac{F_\infty}{\pi r_d^2}(\beta -1)$$ which is exactly what GalSim uses to compute xValue (see SBMoffat.cpp). Now, for the k-space expression, the Fourier transform definition above leads to $$f(k) = 4 F_\infty \frac{\tilde{k}^{\beta-1}}{2^\beta} \frac{K_{\beta-1}(\tilde{k})}{\Gamma(\beta-1)}$$ with $\tilde{k} = k r_d$ (dimensionless) and $K_\nu$ the modified Bessel function of the second kind. This leads to the following asymptotic expression as $k\rightarrow \infty$: $$f(k) \approx F_\infty \frac{\sqrt{\pi}}{\Gamma(\beta-1)} \left(\frac{\tilde{k}}{2} \right)^{\beta-3/2} e^{-\tilde{k}}$$ To find maxk, one then needs to solve $$gs_k = \frac{\sqrt{\pi}}{\Gamma(\beta-1)} \left(\frac{\tilde{k}}{2} \right)^{\beta-3/2} e^{-\tilde{k}}$$ leading to the transcendental equation $$x=-\log\alpha + (\beta -3/2)\log(x/2)$$ where $x$ stands for $\tilde{k}$ and $$\alpha = \frac{gs_k \Gamma(\beta-1)}{\sqrt{\pi}}$$ We can solve this by iteration starting with $x=-\log\alpha$.
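A sketch of that fixed-point iteration, assuming $\alpha < 1$ so that the starting point $x=-\log\alpha$ is positive (true for moderate $\beta$; for large $\beta$, $\Gamma(\beta-1)$ makes $\alpha > 1$ and this starting point fails). The map contracts when $x > \beta - 3/2$ near the root. The helper name moffat_maxk_asymptotic is hypothetical, and this solves only the leading-order asymptotic equation.

```python
import math


def moffat_maxk_asymptotic(beta, r_d, gs_k=1e-3, n_iter=100, tol=1e-12):
    """Solve x = -log(alpha) + (beta - 3/2) log(x/2) by fixed-point iteration.

    x is the dimensionless k * r_d; the returned value is maxk = x / r_d.
    Assumes alpha < 1 so the starting point -log(alpha) is positive.
    """
    alpha = gs_k * math.gamma(beta - 1.0) / math.sqrt(math.pi)
    x = -math.log(alpha)  # starting point suggested in the text
    for _ in range(n_iter):
        x_new = -math.log(alpha) + (beta - 1.5) * math.log(x / 2.0)
        if abs(x_new - x) < tol:
            return x_new / r_d
        x = x_new
    return x / r_d


beta, gs_k = 5.0, 1e-3
x = moffat_maxk_asymptotic(beta, r_d=1.0, gs_k=gs_k)
# Check: the asymptotic ratio at the solution equals the threshold
print(x, math.sqrt(math.pi) / math.gamma(beta - 1) * (x / 2) ** (beta - 1.5) * math.exp(-x))
```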
So far so good, BUT GalSim still uses a wrong asymptotic expression, which gives $$gs_k = \frac{2\sqrt{\pi}}{\Gamma(\beta-1)} \left(\frac{\tilde{k}}{2} \right)^{\beta-1/2} e^{-\tilde{k}}$$ and which then amounts to solving $$x=-\log a + (\beta -1/2)\log x$$ with $$a = \frac{gs_k \Gamma(\beta-1)}{2\sqrt{\pi}} 2^{\beta-1/2}$$
So, for the time being, I have coded the GalSim way of computing $maxK$ to get the GalSim vs JAX-GalSim comparison as close as possible, but I've also kept, commented out, what I would consider the right computation, and raised an issue in the GalSim repository, see #1208. @rmjarvis has some arguments for keeping the current way of computing maxk.
FYI, #1208 has already been fixed in the main branch of GalSim. cf. https://github.com/GalSim-developers/GalSim/pull/1210
But note that we don't use the corrected asymptotic expansion either. It turns out that neither formula actually gives very good results, so we just use the full calculation to compute maxk without approximation.
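A sketch of the "full calculation" route, assuming SciPy's kv, gamma and brentq (this is not GalSim's C++ solver): solve the exact ratio $f(k)/F_\infty = gs_k$ from the Bessel expression above. The helper names and the bracket $[10^{-3}, 100]$ in $x = k r_d$ are assumptions that suit moderate $\beta$; for large $\beta$, kv overflows at small $x$.

```python
import math

from scipy.optimize import brentq
from scipy.special import gamma, kv


def moffat_kratio(x, beta):
    """Full f(k)/F_inf for the untruncated Moffat, with x = k * r_d (dimensionless)."""
    nu = beta - 1.0
    return 4.0 * x**nu / 2.0**beta * kv(nu, x) / gamma(nu)


def moffat_maxk_full(beta, r_d, gs_k=1e-3, x_lo=1e-3, x_hi=100.0):
    # f(k)/F_inf decreases from 1 at x = 0, so [x_lo, x_hi] brackets the threshold crossing
    x = brentq(lambda x: moffat_kratio(x, beta) - gs_k, x_lo, x_hi)
    return x / r_d


print(moffat_maxk_full(5.0, 1.0))
# Sanity check: for beta = 3/2 the FT reduces exactly to exp(-x)
print(moffat_kratio(0.7, 1.5), math.exp(-0.7))
```

Since $K_\nu(x)$ increases with $\nu$ and $K_{1/2}(x)=\sqrt{\pi/(2x)}\,e^{-x}$, the exact ratio is larger than the leading-order asymptotic one for $\beta > 3/2$, so the exact maxk always lies above the asymptotic estimate.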
What I would also like to propose is the following: use the self._workspace = {} mechanism used in jax_cosmo/core.py:Cosmology to store maxk and stepk. E.g.
# in __init__
# .....
self._ws = {}  # not stored in params nor gsparams
# ......

# then use self._ws to store the value of maxk at first use
@property
def _maxk(self):
    if "_maxk" not in self._ws:
        res = jax.lax.select(self.trunc > 0, self._maxk_trunc, self._maxk_untrunc)
        self._ws["_maxk"] = res  # same key as the membership test above
    return self._ws["_maxk"]
What do you think @EiffL? I have made a demo which passes our jitting and vmapping tests.
Oh! Sorry, I missed @rmjarvis's comment, thanks. So is the new computation activated when we use, for the JAX-GalSim installation, install_requires "galsim >= 2.3.0"?
One addition: looking at the GalSim/releases/2.4 test_moffat.py file, it seems that the "maxk" numerical result is only tested in the truncated case. Is that right? Wouldn't it be valuable to extend it to the untruncated case too?
install_requires "galsim >= 2.3.0"
No. It will be released with 2.5.0, but we're not quite ready for that yet. (The current version is 2.4.11.)
Hi @rmjarvis
Maybe this message would be more appropriate for the GalSim repo (let me know), but while looking at a possible evolution of the Moffat code for JAX-GalSim, I was looking at the SBMoffat.cpp code in the GalSim main branch. I am curious:
// Set maxK to the value where the FT is down to maxk_threshold
double SBMoffat::SBMoffatImpl::maxK() const
{
    if (_maxk == 0.) {
        if (_trunc == 0.) {
            MoffatMaxKSolver func(this, this->gsparams.maxk_threshold);
            double ksq1 = 0.;
            double ksq2 = 100;  // k=10 is usually close, so it makes a good starting guess.
            Solve<MoffatMaxKSolver> solver(func,ksq1,ksq2);
            ....
don't you feel that ksq2=100 (i.e. an upper bound on k of 10) is a bit low? For instance, rd = 2., beta=30 and k_threshold = 1e-3 would lead to a maxK of 15 or so => maxK^2 = 225.
ie. for instance rd = 2., beta=30 and k_threshold = 1e-3 would leads to a maxK of 15 or so => maxK^2 = 225.
Did you try it? I'm pretty sure it gets that right.
I managed to get the main-branch GalSim running, and you're right: the code finds maxK=14.708. So I guess the solver pushes the upper bound further and does the job.
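A sketch of how starting at ksq2 = 100 can still find a larger maxk, assuming a simple doubling of the upper bound (this doubling rule and the helper names are assumptions, not GalSim's exact Solve<> strategy) and SciPy's kv/brentq for the full Moffat ratio. For beta=30, rd=2 the thread reports 14.708 from GalSim; this sketch solves the same ratio, so it should land close to that if the formula matches GalSim's.

```python
from scipy.optimize import brentq
from scipy.special import gamma, kv


def moffat_kratio(x, beta):
    """Full f(k)/F_inf for the untruncated Moffat, with x = k * r_d (dimensionless)."""
    nu = beta - 1.0
    return 4.0 * x**nu / 2.0**beta * kv(nu, x) / gamma(nu)


def moffat_maxk_expanding(beta, r_d, gs_k=1e-3, ksq_hi=100.0):
    """Solve f(k)/F_inf = gs_k in k^2, doubling ksq_hi until it brackets the root.

    Starts from ksq_hi = 100 (k = 10) as in SBMoffat.cpp; the doubling rule is an
    assumption made for this sketch.
    """
    def g(ksq):
        return moffat_kratio(ksq**0.5 * r_d, beta) - gs_k

    ksq_lo = 1e-6  # small positive start: at k = 0 the formula evaluates 0 * inf
    while g(ksq_hi) > 0.0:  # ratio still above threshold: push the bound further
        ksq_lo, ksq_hi = ksq_hi, 2.0 * ksq_hi
    return brentq(g, ksq_lo, ksq_hi) ** 0.5


maxk = moffat_maxk_expanding(beta=30.0, r_d=2.0)
print(maxk)  # well above the initial k = 10 guess
```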
We can close the thread. Thanks @rmjarvis
@jecampagne @EiffL I'm merging now to avoid deprecation; I think Moffat passes almost all tests in test_moffat.py. I will make an issue to document the tests that can still be activated.
This is an import of the code I have developed in my original fork of the repo. I have commented the _drawKImage method using my draw_by_kValue implementation. Now, I have introduced bessel.py, integrate.py and interpolate.py. I have tried to mimic the GalSim code for both the truncated and non-truncated Moffat, but see GalSim issues #1208 and #1209.