jecampagne closed this 9 months ago.
I do not understand the following message from the pytest run in the PR, which causes a test failure:
test_api_gsobject[docs-methods]:
self = <[AttributeError("'Spergel' object has no attribute 'beta'") raised in repr()] Spergel object at 0x7f77ae66dd90>
Now test_api.py has detected that my xValue is not jit/grad friendly. I must investigate.
After some tests, the problem seems to be in fsmallz_nu(z, nu), because the following patch, which bypasses it,
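import jax.numpy as jnp
import jax_galsim as galsim  # assumption: "galsim" refers to jax_galsim here

def new_xValue(self, pos):
    r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0
    # original: jnp.where(z <= 1.0e-10, fsmallz_nu(z, nu), jnp.power(z, nu) * _Knu(nu, z))
    # test: both branches use the direct formula, so fsmallz_nu is never evaluated
    res = jnp.where(
        r <= 1.0e-10,
        jnp.power(r, self.nu) * galsim.spergel._Knu(self.nu, r),
        jnp.power(r, self.nu) * galsim.spergel._Knu(self.nu, r),
    )
    return self._xnorm * res

# monkey-patch the method for the test
galsim.spergel.Spergel._xValue = new_xValue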
works.
Well, I'm a bit puzzled, but I guess this is related to https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
In fact, fsmallz_nu cannot be evaluated at integer nu (its generic series involves Gamma functions with poles at integer nu), and so neither can its gradient. So I need to protect it as described in the above doc. The problem is that jnp.power(z, nu) * _Knu(nu, z) is not accurate for small values of z, which is a problem for very sharp Spergel profiles.
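For reference, the pattern recommended in that FAQ (often called the "double where") can be sketched on a toy function. The example below uses sin(z)/z with a small-z series branch as a stand-in for fz_nu; naive_sinc, safe_sinc and SMALL_Z are illustrative names, not JAX-GalSim code. The idea is to feed the singular branch a harmless dummy input wherever the series branch is selected, so that its discarded gradient stays finite.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

SMALL_Z = 1.0e-10  # same threshold as in fz_nu


def naive_sinc(z):
    # Both branches are always evaluated. At z = 0, sin(z)/z is 0/0 and its
    # gradient is NaN; jnp.where multiplies it by 0, and 0 * NaN = NaN.
    return jnp.where(z <= SMALL_Z, 1.0 - z**2 / 6.0, jnp.sin(z) / z)


def safe_sinc(z):
    # "Double where": give the singular branch a dummy input (1.0) wherever
    # the small-z series is selected, so its gradient stays finite.
    small = z <= SMALL_Z
    z_safe = jnp.where(small, 1.0, z)
    return jnp.where(small, 1.0 - z**2 / 6.0, jnp.sin(z_safe) / z_safe)


print(jax.grad(naive_sinc)(0.0))  # nan
print(jax.grad(safe_sinc)(0.0))   # zero (possibly printed as -0.0)
```

The same idea carries over to fz_nu: each branch of the jnp.where should only ever receive inputs for which it is finite.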
A working hack is the following (code tested in Google Colab):
import jax.numpy as jnp
import jax_galsim as galsim  # assumption: "galsim" refers to jax_galsim here

def new_fsmallz_nu(z, nu):
    def fnu(z, nu):
        """z^nu K_nu[z] for z -> 0, series kept up to O(z^4), z > 0"""
        nu += 1.0e-10  # to guarantee that nu is not an integer
        z2 = z * z
        z4 = z2 * z2
        c1 = jnp.power(2.0, -6.0 - nu)
        c2 = galsim.spergel._gamma(-2.0 - nu)
        c3 = galsim.spergel._gamma(-2.0 + nu)
        c4 = jnp.power(z, 2.0 * nu)
        # z^4 + 8 z^2 (2 + nu) + 32 (1 + nu)(2 + nu): coefficients of the z^(2 nu) series
        c5 = z4 + 8.0 * z2 * (2.0 + nu) + 32.0 * (1.0 + nu) * (2.0 + nu)
        c6 = z2 * (16.0 + z2 - 8.0 * nu) * c3
        return c1 * (c4 * c5 * c2 + jnp.power(4.0, nu) * (c6 + 32.0 * galsim.spergel._gamma(nu)))

    # integer nu values use the dedicated expansions f0..f4
    return jnp.select(
        [nu == 0, nu == 1, nu == 2, nu == 3, nu == 4],
        [galsim.spergel.f0(z), galsim.spergel.f1(z), galsim.spergel.f2(z), galsim.spergel.f3(z), galsim.spergel.f4(z)],
        default=fnu(z, nu),
    )
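Because fnu packs the two small-z series of $z^\nu K_\nu(z)$ into c1..c6, a cheap sanity check is the closed form at $\nu = 1/2$, where $K_{1/2}(z) = \sqrt{\pi/(2z)}\,e^{-z}$ and hence $z^{1/2}K_{1/2}(z) = \sqrt{\pi/2}\,e^{-z}$. The sketch below is a standalone plain-Python version of the series behind fnu, using math.gamma; small_z_series and exact_half are hypothetical names. With both series kept up to their $z^4$ corrections, the difference should shrink like $z^6$.

```python
import math


def small_z_series(z, nu):
    """Small-z series of z^nu K_nu(z) for non-integer nu, z > 0.

    Sum of the z^0 series and the z^(2 nu) series, each kept up to its z^4 term.
    """
    z2 = z * z
    z4 = z2 * z2
    c1 = 2.0 ** (-6.0 - nu)
    c2 = math.gamma(-2.0 - nu)
    c3 = math.gamma(-2.0 + nu)
    c4 = z ** (2.0 * nu)
    c5 = z4 + 8.0 * z2 * (2.0 + nu) + 32.0 * (1.0 + nu) * (2.0 + nu)
    c6 = z2 * (16.0 + z2 - 8.0 * nu) * c3
    return c1 * (c4 * c5 * c2 + 4.0**nu * (c6 + 32.0 * math.gamma(nu)))


def exact_half(z):
    # For nu = 1/2: z^nu K_nu(z) = sqrt(pi/2) * exp(-z)
    return math.sqrt(math.pi / 2.0) * math.exp(-z)


for z in (1e-3, 1e-2, 1e-1):
    print(z, small_z_series(z, 0.5), exact_half(z))
```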
Notice that the NaN actually originates from the jnp.where used in
def fz_nu(z, nu):
    """z^nu K_nu[z], z > 0"""
    return jnp.where(z <= 1.0e-10, fsmallz_nu(z, nu), jnp.power(z, nu) * _Knu(nu, z))
I'm investigating the pytest failure in test_spergel.py::test_spergel (do_shoot) for nu > -0.3. But if I run, locally at the root of JAX-GalSim (Python 3.11.7):
python tests/GalSim/tests/test_spergel.py
no crash/failure is detected, while if I use
pytest -k test_spergel.py
then a failure is detected (the same one as in the PR's automated run). Is there a simple explanation for the difference between the two commands?
Here is a more detailed look at what is printed before the pytest short test summary, only for nu=0.0, for which do_shoot is triggered:
Testing Spergel with nu=0.000000
nyquist_scale, stepk, maxk = 0.07906413946811709 0.6283185307179586 39.734735301288545
kimage scale,bounds = 0.6283185307179586 galsim.BoundsI(-64,64,-64,64)
k flux: 1.0 1.0 (1+0j)
k: i,j = 2 3 (0.23544140723892693+0j) 0.23544140723892693
k: i,j = -4 1 (0.19060241763721406+0j) 0.19060241763721406
k: i,j = 0 -5 (0.1380283775837044+0j) 0.1380283775837044
k: i,j = -3 -3 (0.18193996090252673+0j) 0.18193996090252673
0.9753151883124384 [Array(0.97531519, dtype=float64, weak_type=True), Array(0.97531519, dtype=float64, weak_type=True), Array(0.97531519, dtype=float64, weak_type=True), Array(0.97531519, dtype=float64, weak_type=True), Array(0.97531519, dtype=float64, weak_type=True)]
0.4832483199947 [Array(0.48324832, dtype=float64, weak_type=True), Array(0.48324832, dtype=float64, weak_type=True), Array(0.48324832, dtype=float64, weak_type=True), Array(0.48324832, dtype=float64, weak_type=True), Array(0.48324832, dtype=float64, weak_type=True)]
0.0014147098942801993 [Array(0.00141471, dtype=float64, weak_type=True), Array(0.00141471, dtype=float64, weak_type=True), Array(0.00141471, dtype=float64, weak_type=True), Array(0.00141471, dtype=float64, weak_type=True), Array(0.00141471, dtype=float64, weak_type=True)]
Testing Rotated Spergel with nu=0.000000
nyquist_scale, stepk, maxk = 0.07906413946811708 0.6283185307179587 39.73473530128855
kimage scale,bounds = 0.6283185307179587 galsim.BoundsI(-64,64,-64,64)
k flux: 1.0 (1+0j) (1+0j)
k: i,j = 2 3 (0.23544140723892684+0j) (0.23544140723892684+0j)
k: i,j = -4 1 (0.190602417637214+0j) (0.190602417637214+0j)
k: i,j = 0 -5 (0.13802837758370434+0j) (0.13802837758370434+0j)
k: i,j = -3 -3 (0.18193996090252665+0j) (0.18193996090252665+0j)
(0.9753151883124384+0j) [Array(0.97531519+0.j, dtype=complex128), Array(0.97531519+0.j, dtype=complex128), Array(0.97531519+0.j, dtype=complex128), Array(0.97531519+0.j, dtype=complex128), Array(0.97531519+0.j, dtype=complex128)]
(0.4832483199947+0j) [Array(0.48324832+0.j, dtype=complex128), Array(0.48324832+0.j, dtype=complex128), Array(0.48324832+0.j, dtype=complex128), Array(0.48324832+0.j, dtype=complex128), Array(0.48324832+0.j, dtype=complex128)]
(0.0014147098942801993+0j) [Array(0.00141471+0.j, dtype=complex128), Array(0.00141471+0.j, dtype=complex128), Array(0.00141471+0.j, dtype=complex128), Array(0.00141471+0.j, dtype=complex128), Array(0.00141471+0.j, dtype=complex128)]
New, saved array sizes: (16, 16) (16, 16)
Sum of values: 0.766886442899704 0.7671570174861699
Minimum image value: 0.00051524624 0.00051525916
Maximum image value: 0.018770548 0.01884067
Peak location: 119 119
Moments Mx, My, Mxx, Myy, Mxy for new array:
8.5 8.5 11.298789 11.298789 0
Moments Mx, My, Mxx, Myy, Mxy for saved array:
8.5 8.5 11.294338 11.294338 0
Start do_shoot
prof.flux = 1.0
flux_max = 0.011384148
flux_tot = 0.9587079951379565
nphot = 210535.73408575248
img2.sum => 0.9999976148537826
img2.max = 0.06832491
New, saved array sizes: (31, 31) (31, 31)
Sum of values: 0.9999976148537826 0.9587079951379565
Minimum image value: 0.0 2.776471e-05
Maximum image value: 0.06832491 0.011384148
Peak location: 480 480
Moments Mx, My, Mxx, Myy, Mxy for new array:
-0.0016054174 -0.0012776956 2.3415176 2.336859 -0.0023389464
Moments Mx, My, Mxx, Myy, Mxy for saved array:
0 0 25.369039 25.369039 0
And now with the direct Python call:
Testing Spergel with nu=0.000000
nyquist_scale, stepk, maxk = 0.07906413946811708 0.6283185307179586 39.73473530128855
kimage scale,bounds = 0.6283185307179586 galsim.BoundsI(-64,64,-64,64)
k flux: 1.0 1.0 (1+0j)
k: i,j = 2 3 (0.235441407238927+0j) 0.23544140723892704
k: i,j = -4 1 (0.19060241763721414+0j) 0.19060241763721417
k: i,j = 0 -5 (0.13802837758370443+0j) 0.13802837758370445
k: i,j = -3 -3 (0.18193996090252681+0j) 0.1819399609025268
0.9753151883124384 [0.9753151883124384, 0.9753151883124384, 0.9753151883124384, 0.9753151883124384, 0.9753151883124384]
0.48324831999470014 [0.48324831999470014, 0.48324831999470025, 0.48324831999470014, 0.48324831999470025, 0.48324831999470014]
0.0014147098942802001 [0.0014147098942802001, 0.0014147098942802003, 0.0014147098942802003, 0.0014147098942801997, 0.0014147098942801997]
Testing Rotated Spergel with nu=0.000000
nyquist_scale, stepk, maxk = 0.07906413946811708 0.6283185307179586 39.73473530128855
kimage scale,bounds = 0.6283185307179586 galsim.BoundsI(-64,64,-64,64)
k flux: 1.0 1.0 (1+0j)
k: i,j = 2 3 (0.2354414072389168+0j) 0.23544140723892698
k: i,j = -4 1 (0.1906024176372209+0j) 0.1906024176372141
k: i,j = 0 -5 (0.13802837758370645+0j) 0.13802837758370445
k: i,j = -3 -3 (0.18193996090253395+0j) 0.1819399609025268
0.9753151883124384 [0.9753151883124384, 0.9753151883124384, 0.9753151883124384, 0.9753151883124384, 0.9753151883124384]
0.48324831999470014 [0.48324831999470014, 0.48324831999470014, 0.48324831999470025, 0.48324831999470025, 0.48324831999470014]
0.0014147098942802001 [0.0014147098942802001, 0.0014147098942802003, 0.0014147098942802006, 0.0014147098942802001, 0.0014147098942802001]
New, saved array sizes: (16, 16) (16, 16)
Sum of values: 0.7668864142615348 0.7671569886151701
Minimum image value: 0.00051524624 0.00051525916
Maximum image value: 0.018770548 0.018840669
Peak location: 119 119
Moments Mx, My, Mxx, Myy, Mxy for new array:
8.5 8.5 11.298789 11.298789 0
Moments Mx, My, Mxx, Myy, Mxy for saved array:
8.5 8.5 11.294338 11.294338 -4.9318432e-08
Start do_shoot
prof.flux = 1.0
flux_max = 0.011384147
flux_tot = 0.9587079531120253
nphot = 210535.74208036572
img2.sum => 0.9589061376536847
img2.max = 0.010995669
New, saved array sizes: (31, 31) (31, 31)
Sum of values: 0.9589061376536847 0.9587079531120253
Minimum image value: 1.42493445e-05 2.7764725e-05
Maximum image value: 0.010995669 0.011384147
Peak location: 480 480
Moments Mx, My, Mxx, Myy, Mxy for new array:
0.015340344 -0.019213871 25.283831 25.47646 0.10966911
Moments Mx, My, Mxx, Myy, Mxy for saved array:
0 0 25.369039 25.369038 0
... (the output continues, as no failure is detected)
Clearly, in the pytest run the moments do not agree for the last case, which is surely a sign of a problem in the image array.
But why do the moments agree in the "direct Python call" of test_spergel.py?
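To make the "Moments Mx, My, Mxx, Myy, Mxy" lines concrete: they are flux-weighted centroids and second moments of the image array. The sketch below computes central moments in pixel units; it is assumed to be close to what the GalSim test helper reports, but image_moments and its x0/y0 convention are hypothetical. In the pytest log, the shot image has Mxx ≈ 2.34 against ≈ 25.37 for the saved one, i.e. the photons land much too close to the centre.

```python
import numpy as np


def image_moments(arr, x0=0.0, y0=0.0):
    """Flux-weighted centroid (Mx, My) and central second moments (Mxx, Myy, Mxy).

    Column j of `arr` is placed at x = x0 + j and row i at y = y0 + i.
    """
    arr = np.asarray(arr, dtype=float)
    ny, nx = arr.shape
    y, x = np.mgrid[0:ny, 0:nx]
    x = x + x0
    y = y + y0
    flux = arr.sum()
    mx = (x * arr).sum() / flux
    my = (y * arr).sum() / flux
    mxx = ((x - mx) ** 2 * arr).sum() / flux
    myy = ((y - my) ** 2 * arr).sum() / flux
    mxy = ((x - mx) * (y - my) * arr).sum() / flux
    return mx, my, mxx, myy, mxy


# Usage: a round Gaussian (sigma = 3 pixels) centred on a 31 x 31 grid.
yy, xx = np.mgrid[-15:16, -15:16]
gauss = np.exp(-(xx**2 + yy**2) / (2.0 * 3.0**2))
print(image_moments(gauss, x0=-15.0, y0=-15.0))  # approximately (0, 0, 9, 9, 0)
```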
The direct call doesn't replace galsim with jax_galsim.
Ok, I see.
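One plausible mechanism for such a replacement (an assumption: this sketch does not reproduce how the JAX-GalSim pytest configuration actually does it) is to pre-populate sys.modules, so that every later `import galsim` resolves to another module. The stand-in module and its BACKEND attribute are hypothetical.

```python
import sys
import types

# Hypothetical stand-in for jax_galsim (a real harness would import the actual package).
fake_jax_galsim = types.ModuleType("fake_jax_galsim")
fake_jax_galsim.BACKEND = "jax"  # hypothetical attribute, only used to tell the modules apart

# The import system looks in sys.modules first, so once this entry exists every
# later "import galsim" returns the stand-in, whether or not GalSim is installed.
sys.modules["galsim"] = fake_jax_galsim

import galsim  # noqa: E402

print(galsim.BACKEND)  # prints "jax"
```

Running the test file directly with `python` skips any such hook, so it exercises the original GalSim.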
Making progress: the do_shoot test for nu > -0.3 is OK. Now an assert fails later, so work is still needed.
I have made progress in getting test_spergel.py::test_spergel do_shoot working. Now, in test_spergel.py::test_spergel_shoot, the nu > 0 tests are OK. The test with
obj = galsim.Spergel(nu=-0.6, half_light_radius=3.5, flux=1.e4)
fails because added_flux, returned by
added_flux, photons = obj.drawPhot(im, poisson_flux=False, rng=rng.duplicate())
differs from obj.flux:
obj.flux = 10000.0
added_flux = 9999.0
Now, looking at SBSpergel.cpp::SpergelInfo::shoot, the C++ GalSim comment for the nu <= 0 case reads:
// exact s.b. profile diverges at origin, so replace the inner most circle
// (defined such that enclosed flux is shoot_acccuracy) with a linear function
// that contains the same flux and has the right value at r = rmin.
// So need to solve the following for a and b:
// int(2 pi r (a + b r) dr, 0..rmin) = shoot_accuracy
// a + b rmin = K_nu(rmin) * rmin^nu
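The two conditions in that comment form a 2x2 linear system. Since $\int_0^{r_{min}} 2\pi r\,(a + b r)\,dr = \pi a r_{min}^2 + \tfrac{2\pi}{3} b\, r_{min}^3$, substituting $a = F - b\,r_{min}$ with $F = K_\nu(r_{min})\, r_{min}^\nu$ gives $b = 3(\pi F r_{min}^2 - S)/(\pi r_{min}^3)$, where $S$ is shoot_accuracy. A minimal sketch (linear_core is a hypothetical name) that takes the equations literally as written, i.e. without any profile normalization factor, which is precisely the point questioned further down in this thread:

```python
import math


def linear_core(rmin, f_rmin, shoot_accuracy):
    """Solve for (a, b) in the linear inner profile a + b r used for r < rmin.

    Conditions (taken literally from the SBSpergel.cpp comment):
      int_0^rmin 2 pi r (a + b r) dr = shoot_accuracy   (enclosed flux)
      a + b * rmin = f_rmin                             (value at rmin)
    The integral equals pi a rmin^2 + (2 pi / 3) b rmin^3.
    """
    b = 3.0 * (math.pi * f_rmin * rmin**2 - shoot_accuracy) / (math.pi * rmin**3)
    a = f_rmin - b * rmin
    return a, b


# Usage with illustrative numbers (not taken from GalSim).
rmin, f_rmin, acc = 0.01, 5.0, 2.0e-3
a, b = linear_core(rmin, f_rmin, acc)
enclosed = math.pi * a * rmin**2 + 2.0 * math.pi / 3.0 * b * rmin**3
print(a, b, enclosed)  # enclosed reproduces shoot_accuracy
```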
My comment is the following:
1) The Spergel profile, denoted $I_\nu(r) = x_{norm}(\nu)\, f(r/r_0,\nu)$, is such that
$$\Large f(z,\nu)=z^\nu K_\nu(z)$$
which has the following behaviour as $z \to 0$ (or $r \to 0$ for $I_\nu(r)$):
$$
\Large
f(z,\nu) \simeq
\begin{cases}
2^{\nu-1}\,\Gamma(\nu), & \nu > 0\\
\log(1/z), & \nu = 0\\
2^{-\nu-1}\,\Gamma(-\nu)\,/\,z^{-2\nu}, & \nu < 0
\end{cases}
$$
So I agree that $I_\nu(r)$ diverges at the origin for $\nu \le 0$.
2) But the cumulative flux, which is proportional to
$$\Large \int_0^z f(u,\nu)\, u\, du,$$
does not diverge at $z = 0$: near the origin the integrand behaves as $u^{2\nu+1}$ for $\nu < 0$ (and as $u\log(1/u)$ for $\nu = 0$), which is integrable as long as $\nu > -1$ (the bound of the original paper), and GalSim restricts $\nu$ to $[-0.85, 4.0]$.
So I do not see the need for the procedure mentioned in GalSim, and I suspect that it introduces the difference detected by pytest.
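Point (2) can be made explicit. Using $\frac{d}{du}\left[u^{\nu+1}K_{\nu+1}(u)\right] = -u^{\nu+1}K_\nu(u)$, the enclosed flux has the closed form $\int_0^z u^{\nu+1}K_\nu(u)\,du = 2^\nu\,\Gamma(\nu+1) - z^{\nu+1}K_{\nu+1}(z)$ for $\nu > -1$, so the total flux is $2^\nu\,\Gamma(\nu+1)$. The sketch below checks this numerically for $\nu = -0.6$ (the failing case above) with SciPy's kv and quad; the helper names are illustrative.

```python
import math

from scipy.integrate import quad
from scipy.special import kv


def enclosed_flux_closed_form(z, nu):
    """int_0^z u^(nu+1) K_nu(u) du = 2^nu Gamma(nu+1) - z^(nu+1) K_(nu+1)(z), for nu > -1."""
    return 2.0**nu * math.gamma(nu + 1.0) - z ** (nu + 1.0) * kv(nu + 1.0, z)


def enclosed_flux_quad(z, nu):
    # Integrand f(u, nu) * u = u^(nu+1) K_nu(u), which behaves like u^(2 nu + 1) near 0:
    # singular for nu < -0.5 but integrable as long as nu > -1.
    val, _ = quad(lambda u: u ** (nu + 1.0) * kv(nu, u), 0.0, z,
                  epsabs=1e-12, epsrel=1e-10, limit=200)
    return val


nu = -0.6
for z in (0.01, 0.5, 2.0, 5.0):
    print(z, enclosed_flux_quad(z, nu), enclosed_flux_closed_form(z, nu))
print("total flux:", 2.0**nu * math.gamma(nu + 1.0))
```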
Now, despite what I've written above, I will try to implement the GalSim procedure.
I have opened an issue in the GalSim repo to understand how they handle the case nu<0, as I may have found a problem in the normalisation. Let's see what Mike tells me.
I have coded shoot for $\nu \le 0$ "as is" from GalSim. The code passes GalSim/tests/test_spergel.py (i.e. `pytest -k test_spergel.py`). Still, I see a problem with this code (leaving aside the optimisation mentioned by Matthew):
There is no reason not to use inversion of the cumulative flux for $\nu \le 0$ as well, since for $\nu \ge -0.85$ (GalSim bound) the integral $\int_0^r u\, I(u)\, du$ does not diverge as $r \to 0$.
The GalSim way of generating $r$ when $\nu \le 0$ relies on a linear approximation $I(r) = a + b\,r$ for $r < r_{min}$, with $r_{min}$ defined by the shoot_accuracy parameter. Going through the math and looking at the GalSim implementation, I have found a normalization error that leads to a wrong numerical determination of $(a, b)$. This error leads to the 0.1% flux error detected by pytest.
Now, either we stick to the GalSim implementation, since I may be misinterpreting it, or we adopt a more natural approach since, from point (1), there is no need to distinguish $\nu > 0$ and $\nu \le 0$.
Once this is decided, I have to optimize the shoot process following Matthew's remark. (Note: the full pytest run takes 40 min or so.)
Hi, I have implemented a shoot_neg (i.e. for $\nu < 0$) such that 1) it follows the GalSim spirit of using a linear approximation of the profile at low $r$, and 2) it uses a normalization which passes the pytest and is mathematically correct.
I have also used inversion of the cumulative flux, following the Exponential use case (i.e. lazy_property decorator & linear-interpolation inversion). The whole pytest run now lasts 23 min.
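Inverting a tabulated cumulative flux by linear interpolation can be sketched in JAX as below. To stay self-contained (jax.scipy has no $K_\nu$), it uses the Exponential case $\nu = 1/2$, whose enclosed-flux fraction is $1 - (1+z)e^{-z}$ with $z = r/r_0$; for general $\nu$ the table would instead hold $1 - z^{\nu+1}K_{\nu+1}(z)/(2^\nu\Gamma(\nu+1))$. Grid range, grid size and all names are assumptions, not the shoot_neg implementation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def exponential_enclosed_fraction(z):
    # Enclosed-flux fraction for nu = 0.5 (Exponential profile), z = r / r0.
    return 1.0 - (1.0 + z) * jnp.exp(-z)


# Tabulate the CDF once; in a profile class this table could be cached
# (e.g. behind a lazy property) and reused by every shoot call.
Z_GRID = jnp.linspace(0.0, 30.0, 20001)
CDF_GRID = exponential_enclosed_fraction(Z_GRID)


def sample_radii(key, n, r0=1.0):
    """Inverse-CDF sampling: u ~ U(0, 1), then invert the table by linear interpolation."""
    u = jax.random.uniform(key, (n,))
    return r0 * jnp.interp(u, CDF_GRID, Z_GRID)


radii = sample_radii(jax.random.PRNGKey(0), 100_000)
print(jnp.median(radii))  # close to the half-light radius, about 1.678 r0
```

Tabulating once and interpolating avoids running a root finder (such as bisection) per photon.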
There is still a missing check asserting that the nu parameter is in [-0.85, 4.0].
I'll take a look tomorrow!
I have renamed "Knu" (modified Bessel function of the second kind) to "kv" to follow GalSim's notation, and moved "kv" to jax_galsim/bessel.py. I then realised that Knu/kv was also defined and used for Moffat, so I removed that copy and made Moffat use the new one. I have not moved the other z^nu Knu(nu, z)-related functions, since they are specific to the Spergel profile, nor the Gamma function.
Now in the core directory there is a "bessel.py" file that I had created in the past. I think this should be removed and the code moved to jax_galsim/bessel.py. What do you think?
No, we should not do this. The goal is to match the GalSim API exactly, so we only put Bessel functions that GalSim provides in jax_galsim/bessel.py. Anything extra can go in core, since it contains APIs private to jax_galsim.
Who presses the "Merge pull request" button?
Either! I'll go ahead!
This is an implementation of the Spergel profile. In the GalSim code the Spergel index (nu) is restricted to [-0.85, 4.0]. In the JAX code I do not know how to assert that the nu input parameter is valid.
Notice that nu=0.5 is exactly the Exponential profile. The Spergel profile might be a good alternative to the Sersic profile, as realized here.
The present code passes the GalSim test_spergel.py tests except photon shooting, even though I have set up a first version using an inversion method based on the JAX-GalSim bisection root-finding algorithm.
A side remark: both Moffat and Spergel need the K_nu Bessel function. For the time being, it is defined in both moffat.py and spergel.py. I guess this is not optimal, and the function should be defined in the core directory.
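One option for validating nu (a sketch, not what JAX-GalSim does) is jax.experimental.checkify, which turns a runtime check into an explicit error value that also works under jit. A plain Python check such as `if not -0.85 <= nu <= 4.0: raise ValueError(...)` in the constructor behaves like GalSim when nu is a concrete float, but raises a concretization error when nu is a traced value under jit or grad. spergel_power below is a hypothetical stand-in for any nu-dependent computation.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

NU_MIN, NU_MAX = -0.85, 4.0  # GalSim's allowed range for the Spergel index


def spergel_power(nu, z):
    # Hypothetical stand-in for any Spergel computation that depends on nu.
    checkify.check(
        jnp.logical_and(nu >= NU_MIN, nu <= NU_MAX),
        "Spergel nu={nu} is outside [-0.85, 4.0]",
        nu=nu,
    )
    return jnp.power(z, nu)


# checkify makes the check return an error value instead of failing at trace time.
checked_power = jax.jit(checkify.checkify(spergel_power))

err, out = checked_power(0.5, 2.0)
print(err.get())  # None: nu is in range

err, out = checked_power(-0.9, 2.0)
print(err.get())  # an error message reporting the out-of-range nu
# err.throw() would raise it as a Python exception instead
```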