GalSim-developers / JAX-GalSim

JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.

Spergel #86

Closed jecampagne closed 9 months ago

jecampagne commented 9 months ago

This is an implementation of the Spergel profile. In the GalSim code the Spergel index (nu) is restricted to [-0.85, 4.0]. In the JAX code I do not know how to assert against an invalid nu input parameter.
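One possible way to check the range, given here only as a sketch and not as what JAX-GalSim does: validate nu in plain Python when it is a concrete number, and skip the check when it is a tracer (under jit/grad/vmap), since a traced value has nothing concrete to compare. The helper name check_nu and the constants NU_MIN and NU_MAX are hypothetical.

```python
import jax

NU_MIN, NU_MAX = -0.85, 4.0  # GalSim bounds quoted above


def check_nu(nu):
    """Raise ValueError for a concrete out-of-range nu; do nothing for traced values."""
    if isinstance(nu, jax.core.Tracer):
        # Abstract value inside jit/grad/vmap: no concrete number to compare.
        return
    if not (NU_MIN <= float(nu) <= NU_MAX):
        raise ValueError(f"nu={float(nu)} is outside [{NU_MIN}, {NU_MAX}]")
```

The obvious limitation is that an invalid nu passed through a jitted function is not caught; catching that case would need a runtime-check mechanism, which is not shown here.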

Notice that nu=0.5 is exactly the Exponential profile. Spergel profile might be a good alternative to the Sersic profile as realized here.
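The nu=0.5 statement can be checked numerically, since z^0.5 K_0.5(z) = sqrt(pi/2) exp(-z). The sketch below uses only the standard library and evaluates K_nu through the integral representation K_nu(z) = int_0^inf exp(-z cosh t) cosh(nu t) dt with a trapezoid rule. The names bessel_k and spergel_f are illustrative and are not the functions of this PR, and the default t_max assumes z is not extremely small.

```python
import math


def bessel_k(nu, z, t_max=12.0, n=4000):
    """K_nu(z) for z > 0 from int_0^inf exp(-z cosh t) cosh(nu t) dt (trapezoid rule)."""
    h = t_max / n
    total = 0.5 * math.exp(-z)  # t = 0 endpoint carries weight 1/2
    for i in range(1, n + 1):
        t = i * h
        # The integrand decays doubly exponentially, so the tail beyond t_max is negligible.
        total += math.exp(-z * math.cosh(t)) * math.cosh(nu * t)
    return total * h


def spergel_f(z, nu):
    """Unnormalised Spergel radial function f(z, nu) = z^nu K_nu(z)."""
    return z**nu * bessel_k(nu, z)


def exponential_f(z):
    """Closed form of f(z, 0.5): an exponential profile up to a constant."""
    return math.sqrt(math.pi / 2.0) * math.exp(-z)
```

Comparing spergel_f(z, 0.5) with exponential_f(z) at a few values of z shows agreement to numerical precision.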

The present code passes the GalSim tests in test_spergel.py except for photon shooting, even though I have set up a first version based on an inversion method that uses the JAX-GalSim bisection root-finding algorithm.

A side remark: for Moffat and Spergel we need the K_nu Bessel function. For the time being this Bessel function is defined in both moffat.py and spergel.py. I guess this is not optimal and the function should be defined in the core directory.

jecampagne commented 9 months ago

I do not understand this message from the pytest run on the PR, which causes a test failure:

test_api_gsobject[docs-methods]:
self = <[AttributeError("'Spergel' object has no attribute 'beta'") raised in repr()] Spergel object at 0x7f77ae66dd90>

jecampagne commented 9 months ago

Now test_api.py has detected that my xValue is not jit/grad friendly. I must investigate.

After some tests, the problem seems to be in fsmallz_nu(z, nu), because the following replacement, which bypasses it, is OK:

def new_xValue(self, pos):
    r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0
    # original: jnp.where(z <= 1.0e-10, fsmallz_nu(z, nu), jnp.power(z, nu) * _Knu(nu, z))
    # Both branches below use the large-z expression, so fsmallz_nu is never called.
    res = jnp.where(
        r <= 1.0e-10,
        jnp.power(r, self.nu) * galsim.spergel._Knu(self.nu, r),
        jnp.power(r, self.nu) * galsim.spergel._Knu(self.nu, r),
    )
    return self._xnorm * res

galsim.spergel.Spergel._xValue = new_xValue

jecampagne commented 9 months ago

Well I'm a bit puzzled but I guess this is related to https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

In fact fsmallz_nu cannot be used with integer nu, and so neither can its gradient. So I need to protect it as described in the above doc. The problem is that jnp.power(z, nu) * _Knu(nu, z) is not accurate for small values of z, which matters for very sharp Spergel profiles.
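The protection described in that FAQ entry is often called the "double where" trick. A minimal sketch on a toy function (log, not the Spergel code; the names naive and safe are illustrative only):

```python
import jax
import jax.numpy as jnp


def naive(x):
    # The unselected branch log(x) is still differentiated at x = 0,
    # and its infinite derivative turns the masked gradient into nan.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def safe(x):
    # Inner where: feed log a harmless input wherever its branch is not selected.
    x_safe = jnp.where(x > 0.0, x, 1.0)
    # Outer where: select the branch as before.
    return jnp.where(x > 0.0, jnp.log(x_safe), 0.0)


grad_naive = jax.grad(naive)(0.0)  # nan
grad_safe = jax.grad(safe)(0.0)    # 0.0
```

The forward values of naive and safe are identical; only the gradient at the masked point differs.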

A working hack is the following (code tested in Google Colab):

def new_fsmallz_nu(z, nu):
    def fnu(z, nu):
        """z^nu K_nu[z] z -> 0 O(z^4) z > 0"""
        nu += 1.e-10  # to guarantee that nu is not an integer
        z2 = z * z
        z4 = z2 * z2
        c1 = jnp.power(2.0, -6.0 - nu)
        c2 = galsim.spergel._gamma(-2.0 - nu)
        c3 = galsim.spergel._gamma(-2.0 + nu)
        c4 = jnp.power(z, 2.0 * nu)
        c5 = z4 + 8.0 * z2 * (2.0 + nu) + 32.0 * (1.0 + nu) * (2.0 + nu)
        c6 = z2 * (16.0 + z2 - 8.0 * nu) * c3
        return c1 * (c4 * c5 * c2 + jnp.power(4.0, nu) * (c6 + 32.0 * galsim.spergel._gamma(nu)))

    # Integer nu uses the dedicated f0..f4 expansions; any other nu uses fnu.
    return jnp.select(
        [nu == 0, nu == 1, nu == 2, nu == 3, nu == 4],
        [galsim.spergel.f0(z), galsim.spergel.f1(z), galsim.spergel.f2(z), galsim.spergel.f3(z), galsim.spergel.f4(z)],
        default=fnu(z, nu),
    )

Notice that the NaN in fact originates from the jnp.where used in

def fz_nu(z, nu):
    """z^nu K_nu[z], z > 0"""
    return jnp.where(z <= 1.0e-10, fsmallz_nu(z, nu), jnp.power(z, nu) * _Knu(nu, z))
jecampagne commented 9 months ago

I'm trying to investigate the pytest failure in test_spergel.py::test_spergel do_shoot for nu>-0.3. But if I run locally, at the root of JAX-GalSim (Python 3.11.7):

python tests/GalSim/tests/test_spergel.py

there is no crash/failure detected, while if I use

 pytest -k test_spergel.py

then a failure is detected (the same as in the PR's automatic checks). Is there a simple explanation for the difference between the two commands?

Here is a more detailed look at what is printed before the pytest short test summary, only for nu=0.0, for which do_shoot is triggered:

Testing Spergel with nu=0.000000
  nyquist_scale, stepk, maxk =  0.07906413946811709 0.6283185307179586 39.734735301288545
  kimage scale,bounds =  0.6283185307179586 galsim.BoundsI(-64,64,-64,64)
  k flux:  1.0 1.0 (1+0j)
  k: i,j =  2 3 (0.23544140723892693+0j) 0.23544140723892693
  k: i,j =  -4 1 (0.19060241763721406+0j) 0.19060241763721406
  k: i,j =  0 -5 (0.1380283775837044+0j) 0.1380283775837044
  k: i,j =  -3 -3 (0.18193996090252673+0j) 0.18193996090252673
0.9753151883124384 [Array(0.97531519, dtype=float64, weak_type=True), Array(0.97531519, dtype=float64, weak_type=True), Array(0.97531519, dtype=float64, weak_type=True), Array(0.97531519, dtype=float64, weak_type=True), Array(0.97531519, dtype=float64, weak_type=True)]
0.4832483199947 [Array(0.48324832, dtype=float64, weak_type=True), Array(0.48324832, dtype=float64, weak_type=True), Array(0.48324832, dtype=float64, weak_type=True), Array(0.48324832, dtype=float64, weak_type=True), Array(0.48324832, dtype=float64, weak_type=True)]
0.0014147098942801993 [Array(0.00141471, dtype=float64, weak_type=True), Array(0.00141471, dtype=float64, weak_type=True), Array(0.00141471, dtype=float64, weak_type=True), Array(0.00141471, dtype=float64, weak_type=True), Array(0.00141471, dtype=float64, weak_type=True)]
Testing Rotated Spergel with nu=0.000000
  nyquist_scale, stepk, maxk =  0.07906413946811708 0.6283185307179587 39.73473530128855
  kimage scale,bounds =  0.6283185307179587 galsim.BoundsI(-64,64,-64,64)
  k flux:  1.0 (1+0j) (1+0j)
  k: i,j =  2 3 (0.23544140723892684+0j) (0.23544140723892684+0j)
  k: i,j =  -4 1 (0.190602417637214+0j) (0.190602417637214+0j)
  k: i,j =  0 -5 (0.13802837758370434+0j) (0.13802837758370434+0j)
  k: i,j =  -3 -3 (0.18193996090252665+0j) (0.18193996090252665+0j)
(0.9753151883124384+0j) [Array(0.97531519+0.j, dtype=complex128), Array(0.97531519+0.j, dtype=complex128), Array(0.97531519+0.j, dtype=complex128), Array(0.97531519+0.j, dtype=complex128), Array(0.97531519+0.j, dtype=complex128)]
(0.4832483199947+0j) [Array(0.48324832+0.j, dtype=complex128), Array(0.48324832+0.j, dtype=complex128), Array(0.48324832+0.j, dtype=complex128), Array(0.48324832+0.j, dtype=complex128), Array(0.48324832+0.j, dtype=complex128)]
(0.0014147098942801993+0j) [Array(0.00141471+0.j, dtype=complex128), Array(0.00141471+0.j, dtype=complex128), Array(0.00141471+0.j, dtype=complex128), Array(0.00141471+0.j, dtype=complex128), Array(0.00141471+0.j, dtype=complex128)]
New, saved array sizes:  (16, 16) (16, 16)
Sum of values:  0.766886442899704 0.7671570174861699
Minimum image value:  0.00051524624 0.00051525916
Maximum image value:  0.018770548 0.01884067
Peak location:  119 119
Moments Mx, My, Mxx, Myy, Mxy for new array: 
      8.5              8.5              11.298789        11.298789        0              
Moments Mx, My, Mxx, Myy, Mxy for saved array: 
      8.5              8.5              11.294338        11.294338        0              
Start do_shoot
prof.flux =  1.0
flux_max =  0.011384148
flux_tot =  0.9587079951379565
nphot =  210535.73408575248
img2.sum =>  0.9999976148537826
img2.max =  0.06832491
New, saved array sizes:  (31, 31) (31, 31)
Sum of values:  0.9999976148537826 0.9587079951379565
Minimum image value:  0.0 2.776471e-05
Maximum image value:  0.06832491 0.011384148
Peak location:  480 480
Moments Mx, My, Mxx, Myy, Mxy for new array: 
      -0.0016054174    -0.0012776956    2.3415176        2.336859         -0.0023389464  
Moments Mx, My, Mxx, Myy, Mxy for saved array: 
      0                0                25.369039        25.369039        0  

and now with the direct python call:

Testing Spergel with nu=0.000000
  nyquist_scale, stepk, maxk =  0.07906413946811708 0.6283185307179586 39.73473530128855
  kimage scale,bounds =  0.6283185307179586 galsim.BoundsI(-64,64,-64,64)
  k flux:  1.0 1.0 (1+0j)
  k: i,j =  2 3 (0.235441407238927+0j) 0.23544140723892704
  k: i,j =  -4 1 (0.19060241763721414+0j) 0.19060241763721417
  k: i,j =  0 -5 (0.13802837758370443+0j) 0.13802837758370445
  k: i,j =  -3 -3 (0.18193996090252681+0j) 0.1819399609025268
0.9753151883124384 [0.9753151883124384, 0.9753151883124384, 0.9753151883124384, 0.9753151883124384, 0.9753151883124384]
0.48324831999470014 [0.48324831999470014, 0.48324831999470025, 0.48324831999470014, 0.48324831999470025, 0.48324831999470014]
0.0014147098942802001 [0.0014147098942802001, 0.0014147098942802003, 0.0014147098942802003, 0.0014147098942801997, 0.0014147098942801997]
Testing Rotated Spergel with nu=0.000000
  nyquist_scale, stepk, maxk =  0.07906413946811708 0.6283185307179586 39.73473530128855
  kimage scale,bounds =  0.6283185307179586 galsim.BoundsI(-64,64,-64,64)
  k flux:  1.0 1.0 (1+0j)
  k: i,j =  2 3 (0.2354414072389168+0j) 0.23544140723892698
  k: i,j =  -4 1 (0.1906024176372209+0j) 0.1906024176372141
  k: i,j =  0 -5 (0.13802837758370645+0j) 0.13802837758370445
  k: i,j =  -3 -3 (0.18193996090253395+0j) 0.1819399609025268
0.9753151883124384 [0.9753151883124384, 0.9753151883124384, 0.9753151883124384, 0.9753151883124384, 0.9753151883124384]
0.48324831999470014 [0.48324831999470014, 0.48324831999470014, 0.48324831999470025, 0.48324831999470025, 0.48324831999470014]
0.0014147098942802001 [0.0014147098942802001, 0.0014147098942802003, 0.0014147098942802006, 0.0014147098942802001, 0.0014147098942802001]
New, saved array sizes:  (16, 16) (16, 16)
Sum of values:  0.7668864142615348 0.7671569886151701
Minimum image value:  0.00051524624 0.00051525916
Maximum image value:  0.018770548 0.018840669
Peak location:  119 119
Moments Mx, My, Mxx, Myy, Mxy for new array: 
      8.5              8.5              11.298789        11.298789        0              
Moments Mx, My, Mxx, Myy, Mxy for saved array: 
      8.5              8.5              11.294338        11.294338        -4.9318432e-08 
Start do_shoot
prof.flux =  1.0
flux_max =  0.011384147
flux_tot =  0.9587079531120253
nphot =  210535.74208036572
img2.sum =>  0.9589061376536847
img2.max =  0.010995669
New, saved array sizes:  (31, 31) (31, 31)
Sum of values:  0.9589061376536847 0.9587079531120253
Minimum image value:  1.42493445e-05 2.7764725e-05
Maximum image value:  0.010995669 0.011384147
Peak location:  480 480
Moments Mx, My, Mxx, Myy, Mxy for new array: 
      0.015340344      -0.019213871     25.283831        25.47646         0.10966911     
Moments Mx, My, Mxx, Myy, Mxy for saved array: 
      0                0                25.369039        25.369038        0              
.... continuing as no failure is detected

It is clear that in the pytest run the moments do not agree for the last case, which is certainly a sign of a problem in the image array.

But why do the moments agree in the direct python call of test_spergel.py?

beckermr commented 9 months ago

The direct call doesn't replace galsim with jax_galsim.

jecampagne commented 9 months ago

Ok, I see.

jecampagne commented 9 months ago

Making progress: the do_shoot test for nu>-0.3 is OK. Now there is an assertion failure later on, so work is still needed.

jecampagne commented 9 months ago

I have made progress in getting test_spergel.py::test_spergel do_shoot working. Now in test_spergel.py::test_spergel_shoot the nu>0 tests are OK. The test with

 obj = galsim.Spergel(nu=-0.6, half_light_radius=3.5, flux=1.e4)

fails as added_flux from

added_flux, photons = obj.drawPhot(im, poisson_flux=False, rng=rng.duplicate())

differs from obj.flux:

obj.flux =  10000.0
added_flux =  9999.0

Now, looking at the C++ function SpergelInfo::shoot in SBSpergel.cpp, the GalSim comment for the nu<=0 case reads:

     // exact s.b. profile diverges at origin, so replace the inner most circle
     // (defined such that enclosed flux is shoot_acccuracy) with a linear function
    // that contains the same flux and has the right value at r = rmin.
    // So need to solve the following for a and b:
    // int(2 pi r (a + b r) dr, 0..rmin) = shoot_accuracy
    // a + b rmin = K_nu(rmin) * rmin^nu
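The two conditions in that comment form a 2x2 linear system in (a, b), which can be solved in closed form: the integral equals pi a rmin^2 + (2 pi / 3) b rmin^3. The sketch below only solves the system as written. It assumes that the enclosed flux and the profile value at rmin are expressed in the same normalization, which is precisely the point under discussion. The function and argument names are hypothetical.

```python
import math


def linear_core_coeffs(rmin, value_at_rmin, enclosed_flux):
    """Solve for (a, b) such that
         int_0^rmin 2 pi r (a + b r) dr = enclosed_flux
         a + b * rmin                   = value_at_rmin
    """
    # Substitute a = value_at_rmin - b * rmin into the flux condition and solve for b.
    b = 3.0 * (math.pi * rmin**2 * value_at_rmin - enclosed_flux) / (math.pi * rmin**3)
    a = value_at_rmin - b * rmin
    return a, b


def core_flux(a, b, rmin):
    """Flux enclosed within rmin for the linear profile a + b r."""
    return math.pi * a * rmin**2 + 2.0 * math.pi * b * rmin**3 / 3.0
```

Equivalently, a = 3 enclosed_flux / (pi rmin^2) - 2 value_at_rmin, so any mismatch in normalization between the two inputs shifts both coefficients.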

My comment is the following. 1) The Spergel profile, written I_{nu}(r) = xnorm(nu) * f(r/r0,nu) with

$$\Large f(z,\nu)=z^\nu K_\nu(z)$$

has the following behaviour for z->0 (i.e. r->0 for I_{nu}(r)):

$$
\begin{cases}\Large
 f(z,\nu>0) \to 2^{\nu -1} \Gamma(\nu)\\
\Large f(z,\nu=0) \propto \log(1/z) \\
\Large f(z,\nu<0) \propto 1/z^{-2\nu}
\end{cases}
$$

So I agree that I_{nu}(r) diverges for nu<=0. 2) But the cumulative flux, which is proportional to

$$\Large \int_0^z  f(u,\nu)\, u\, du$$

has no divergence at z=0 as long as nu>-1 (the bound in the original paper), which holds since nu is in [-0.85, 4.0] (GalSim bounds).
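The convergence claim can be made explicit with the leading small-argument term of the Bessel function (a standard expansion, written here for nu<0 as a sketch):

$$
f(u,\nu) = u^\nu K_\nu(u) \simeq 2^{-\nu-1}\,\Gamma(-\nu)\,u^{2\nu} \quad (u \to 0,\ \nu<0)
$$

$$
\int_0^z f(u,\nu)\,u\,du \simeq 2^{-\nu-1}\,\Gamma(-\nu)\int_0^z u^{2\nu+1}\,du = 2^{-\nu-1}\,\Gamma(-\nu)\,\frac{z^{2\nu+2}}{2\nu+2}
$$

which is finite exactly when 2 nu + 2 > 0, i.e. nu > -1. For nu = 0 the integrand behaves like u log(1/u), which is also integrable at 0.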

So I do not see the need for the procedure described in GalSim, and I suspect that it introduces the difference detected by pytest.

Now, despite what I've written above, I will try to implement the GalSim procedure.

jecampagne commented 9 months ago

I have opened an issue in the GalSim repo to understand how they handle the nu<0 case, as I may have found a problem in the normalisation. Let's see what Mike tells me.

jecampagne commented 9 months ago

I have coded shoot for nu<=0 "as is" in GalSim. The code passes GalSim/tests/test_spergel.py (i.e. pytest -k test_spergel.py). Now, for me there is a problem with this code (I am not speaking about the optimisation mentioned by Matthew):

  1. There is no reason not to use an inversion method on the cumulative flux for nu<=0, since for nu>=-0.85 (GalSim bound) the integral

    \int_0^r   u I(u) du

    does not diverge when r->0.

  2. The GalSim way of generating r when nu<=0 relies on a linear approximation I(r) = a + b r for r<rmin, with rmin defined by the shoot_accuracy parameter. Going through the math and looking at the GalSim implementation, I have detected a normalization error that leads to a wrong numerical determination of (a,b). This error leads to the 0.1% error on the flux detected by pytest.

Now either we stick to the GalSim coding, because I may be wrong in interpreting it, or we adopt a more natural approach, since by point (1) there is no need to distinguish nu>0 from nu<=0. After deciding, I have to optimize the shoot process following Matthew's remark. (Note: the full pytest run takes 40 min or so.)

jecampagne commented 9 months ago

Hi, I have implemented shoot_neg (i.e. for nu<0) such that 1) it follows the GalSim spirit of using a linear approximation of the profile at low r, and 2) it uses a normalization which passes pytest and is mathematically correct. I have also used the inversion of the cumulative flux, following the Exponential use case (i.e. lazy_property decorator and linear-interpolation inversion). The whole pytest run now lasts 23 min.
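The linear-interpolation inversion of the cumulative flux can be sketched as follows. This is an illustration and not the code of this PR: it uses the toy radial function f(z) = exp(-z) (the nu=0.5 case up to a constant), tabulates the cumulative flux with a trapezoid rule, and inverts the table with jnp.interp. The function name sample_radii, the grid size and z_max are assumptions.

```python
import jax
import jax.numpy as jnp


def sample_radii(key, n_photons, z_max=12.0, n_grid=4097):
    """Draw z = r/r0 by inverting a tabulated cumulative flux (toy profile exp(-z))."""
    z = jnp.linspace(0.0, z_max, n_grid)
    integrand = z * jnp.exp(-z)  # flux per unit radius, up to a constant factor
    steps = 0.5 * (integrand[1:] + integrand[:-1]) * (z[1] - z[0])  # trapezoid rule
    cdf = jnp.concatenate([jnp.zeros(1), jnp.cumsum(steps)])
    cdf = cdf / cdf[-1]  # normalise so that the table ends at 1
    u = jax.random.uniform(key, (n_photons,))
    # Linear-interpolation inversion: look up z at cumulative flux u.
    return u, jnp.interp(u, cdf, z)


u, z_samples = sample_radii(jax.random.PRNGKey(0), 1000)
```

For this toy profile the cumulative flux has the closed form 1 - (1 + z) exp(-z), which gives a direct check that the sampled z map back to u. In a real profile the table would be built once and cached, which is the role of the lazy_property decorator mentioned above.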

jecampagne commented 9 months ago

There is still a missing test to assert that the nu parameter is in [-0.85, 4.0].

beckermr commented 9 months ago

I'll take a look tomorrow!

jecampagne commented 9 months ago

I have renamed "Knu" (modified Bessel function of the 2nd kind) to "kv" to follow the GalSim notation, and moved "kv" to jax_galsim/bessel.py. I then realised that Knu/kv was also defined and used for Moffat, so I have removed that code to use the new one. I have not moved the other z^nu Knu(nu,z) related functions, as they are specific to the Spergel profile, nor the Gamma function.

Now in the core directory there is a "bessel.py" file that I had created in the past. I think this should be removed and the code moved to jax_galsim/bessel.py. What do you think?

beckermr commented 9 months ago

> Now in the core directory there is a "bessel.py" file that I had created in the past. I think this should be removed and the code moved to jax_galsim/bessel.py. What do you think?

No, we should not do this. The goal is to match the galsim API exactly, so we only put Bessel functions that galsim provides in jax_galsim/bessel.py. Anything extra can go in core, since it contains APIs private to jax_galsim.

jecampagne commented 9 months ago

Who presses the Merge pull request button?

beckermr commented 9 months ago

Either! I'll go ahead!