shivampcosmo / GODMAX

Gas thermODynamics and Matter distribution using jAX

Bug in get_Pnt_fac function #1

Open chto opened 1 year ago

chto commented 1 year ago

Hi @shivampcosmo,

I think there is a bug in https://github.com/shivampcosmo/GODMAX/blob/a50f098f4a2acd6fcae5018abf0cd5fac92ca403/src/get_BCMP_profile_jit.py#L468. The arguments should be jr, jc, jz, jM instead of jr, jc, jM, jz.
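
A swapped pair of index arguments like this is easy to miss in JAX, because an out-of-bounds index is clamped to the array edge instead of raising an error as it would in NumPy. The sketch below illustrates this with a hypothetical table indexed as `[jz, jM]`. The names `table` and `lookup` are made up for the illustration and are not taken from GODMAX. Keyword-only arguments are one possible way to make the order explicit at the call site.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical lookup table indexed as [jz, jM]: nz = 3 redshifts, nM = 5 masses.
nz, nM = 3, 5
table_np = np.arange(nz * nM, dtype=np.float32).reshape(nz, nM)
table = jnp.asarray(table_np)


def lookup(*, jz, jM):
    """Keyword-only indices: the caller must name each one explicitly."""
    return table[jz, jM]


# Intended call: redshift index 1, mass index 4.
good = lookup(jz=1, jM=4)

# The same two numbers in swapped order: row index 4 is out of bounds.
# JAX clamps it to the last valid row and silently returns table[2, 1].
swapped = table[4, 1]

# NumPy raises IndexError for the same access.
try:
    table_np[4, 1]
    numpy_raised = False
except IndexError:
    numpy_raised = True
```

With positional arguments the swapped call returns a plausible-looking number, so a check that compares against a NumPy reference or uses very different `nz` and `nM` is more likely to expose it.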

chto commented 1 year ago

Actually, I think there is another bug at https://github.com/shivampcosmo/GODMAX/blob/a50f098f4a2acd6fcae5018abf0cd5fac92ca403/src/get_power_spectra_jit.py#L149

It should be z = self.z_array[jz]

shivampcosmo commented 1 year ago

Oops, you are right. I found the Pnt one before but forgot to push it, and I completely missed the lensing kernel one! Pushed the corrections. Thanks! Lmk if you find something else.

chto commented 1 year ago

Yep. The others are less critical, just personal preferences.

  1. I would use the Tinker 2010 HMF instead of Tinker 2008, because the integral constraint ∫ b(nu) f(nu) dnu = 1 is not satisfied if you mix the Tinker 2008 HMF with the Tinker 2010 bias. Here is the code:

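    @partial(jit, static_argnums=(0,))
    def get_fsigma_Mz(self, jz, jM, mdef_delta=200):
        '''Tinker 2010 mass function'''
        # needs: from functools import partial; from jax import jit; import jax.numpy as jnp
        sigma = self.sigma_Mz_mat[jz, jM]
        delta_c = constants.DELTA_COLLAPSE
        nu = delta_c / sigma  # peak height
        z = self.z_array[jz]
        # convert the overdensity from critical-density to mean-matter units
        rho_threshold = mdef_delta * self.get_rho_c(z)
        Delta_m = jnp.round(rho_threshold / self.get_rho_m(z))  # jnp.round works on traced values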
        fit_Delta = jnp.array([200, 300, 400, 600, 800, 1200, 1600, 2400, 3200])
        fit_alpha = jnp.array([0.368, 0.363, 0.385, 0.389, 0.393, 0.365, 0.379, 0.355, 0.327])
        fit_beta = jnp.array([0.589, 0.585, 0.544, 0.543, 0.564, 0.623, 0.637, 0.673, 0.702])
        fit_gamma =  jnp.array([0.864, 0.922, 0.987, 1.09, 1.20, 1.34, 1.50, 1.68, 1.81])
        fit_phi = jnp.array([-0.729, -0.789, -0.910, -1.05, -1.20, -1.26, -1.45, -1.50, -1.49])
        fit_eta = jnp.array([-0.243, -0.261, -0.261, -0.273, -0.278, -0.301, -0.301, -0.319, -0.336])
        alpha = jnp.interp(Delta_m, fit_Delta, fit_alpha)
        beta = jnp.interp(Delta_m, fit_Delta, fit_beta)
        gamma = jnp.interp(Delta_m, fit_Delta, fit_gamma)
        phi = jnp.interp(Delta_m, fit_Delta, fit_phi)
        eta = jnp.interp(Delta_m, fit_Delta, fit_eta)
    
        # redshift evolution of the fit parameters (alpha is left unchanged)
        beta = beta*(1+z)**0.2
        phi = phi*(1+z)**(-0.08)
        eta = eta*(1+z)**0.27
        gamma = gamma*(1+z)**(-0.01)
        fnu = alpha*(1+(beta*nu)**(-2.0*phi))*nu**(2*eta)*jnp.exp(-gamma*nu**2/2)
        return nu*fnu
  2. I found that the xip and xim transforms are not stable when done with trapz. I think using mcfit is necessary.

Here is some of the code:

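    def get_Hankel_xip(self, jt):

        # order-0 Hankel transform (J_0) of C_ell, as needed for xi_+
        H = Hankel(self.ell_array, nu=0, lowring=True, backend="jax")
        # y: output angles in radians; G: transformed spectra, indexed as G[i][j]
        y, G = H(self.Cl_kappa_kappa_halofit_mat, extrap=False)
        yreturn = []
        for i in range(len(G)):
            for j in range(len(G)):
                # convert y from radians to arcmin, then interpolate to the requested angle
                yreturn.append(jnp.interp(self.angles_data_array[jt], y/((1/60.) * (jnp.pi/180.)), G[i][j]))
        # restore the (i, j) layout and apply the 1/(2*pi) normalization
        yreturn = jnp.array(yreturn).T.reshape(len(G), len(G))/(2*jnp.pi)
        return yreturn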
  3. The transform from rho(r) to u(k) can also be done with mcfit.

Here is some code:

    def get_ukdmb_from_rho(self, jk):

        k = self.kPk_array[jk]
        # constant prefactor 4*pi*sqrt(pi/2), tiled to the (nr, nc, nz, nM) profile grid
        prefac = 4 * jnp.pi*jnp.sqrt(jnp.pi/2)*jnp.ones_like(self.r_array)
        prefac_repeat_shape = jnp.tile(prefac.reshape(self.nr,1,1,1), (1,self.nc,self.nz,self.nM))
        shape = self.rho_dmb_normed_M.shape
        # flatten all non-radial axes so that each column is one radial profile
        inarr = (prefac_repeat_shape*self.rho_dmb_normed_M).reshape(shape[0], -1)
        H = SphericalBessel(self.r_array,lowring=True, backend="jax")
        # transform along r for every flattened profile
        y, G = H(inarr.T, extrap=False)
        yreturn = []
        for i in range(len(inarr[0])):
            # interpolate each transformed profile to the requested k
            yreturn.append(jnp.interp(k, y, G[i]))
        yreturn = jnp.array(yreturn).reshape(shape[1], shape[2], shape[3])

        return yreturn
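
The integral constraint from item 1 can be checked numerically. The following is a minimal standalone sketch, not code from GODMAX: it hard-codes the Delta_m = 200, z = 0 parameters from the table above, assumes delta_c = 1.686, and uses the Tinker 2010 bias fit as I remember it from the paper, so the coefficients in `tinker10_bias` should be double-checked against the publication. The integral is done with a trapezoid rule in ln(nu) because f(nu) has an integrable divergence at small nu.

```python
import math

import jax.numpy as jnp

DELTA_C = 1.686  # assumed spherical-collapse threshold


def tinker10_f(nu, alpha=0.368, beta=0.589, gamma=0.864, phi=-0.729, eta=-0.243):
    """Tinker 2010 f(nu) at z = 0 for Delta_m = 200 (same form as get_fsigma_Mz)."""
    return (alpha * (1.0 + (beta * nu) ** (-2.0 * phi))
            * nu ** (2.0 * eta) * jnp.exp(-0.5 * gamma * nu**2))


def tinker10_bias(nu, delta_m=200.0):
    """Tinker 2010 large-scale bias fit b(nu); coefficients quoted from memory."""
    y = math.log10(delta_m)
    x = math.exp(-((4.0 / y) ** 4))
    A, a = 1.0 + 0.24 * y * x, 0.44 * y - 0.88
    B, b = 0.183, 1.5
    C, c = 0.019 + 0.107 * y + 0.19 * x, 2.4
    return 1.0 - A * nu**a / (nu**a + DELTA_C**a) + B * nu**b + C * nu**c


# Uniform grid in ln(nu); d(nu) = nu d(ln nu) adds one factor of nu.
lo, hi, n = math.log(1e-10), math.log(20.0), 4000
step = (hi - lo) / (n - 1)
nu = jnp.exp(jnp.linspace(lo, hi, n))
integrand = tinker10_bias(nu) * tinker10_f(nu) * nu

# Trapezoid rule with a constant step.
integral = step * jnp.sum(0.5 * (integrand[1:] + integrand[:-1]))
```

If the fits are transcribed correctly, `integral` should come out close to 1 at z = 0. Swapping `tinker10_f` for a Tinker 2008 multiplicity function in the same script is a quick way to see how far the mixed combination drifts from 1.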
shivampcosmo commented 1 year ago

Thanks! I had the transformation with mcfit before, but the issue is that it is not natively coded in jax, so the likelihood is not differentiable. I spent a fair bit of time trying to convert it to jax, without success. If you know of something along those lines, or wanna give it a shot, that would be helpful :). Did you find that it is not stable for some particular choice of angles? I did some quick checks and it was matching well; let me redo the checks carefully. Also, thanks for the Tinker 2010 code. I will add it as an option. Currently there is no integral where the low-mass end is needed, so that consistency relation did not have to hold. But I am thinking of updating the 2-halo term, where it would be useful. Thanks!
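
For reference, a direct-quadrature version of the rho(r) to u(k) transform stays differentiable under `jax.grad` because it uses only `jax.numpy` operations. The sketch below is an illustration under simple assumptions, not the GODMAX or mcfit implementation: `u_of_k` and `gaussian_u` are hypothetical names, the profile is a Gaussian so that the result can be compared with the closed form (2*pi)^(3/2) s^3 exp(-k^2 s^2 / 2), and the integral is a plain trapezoid rule in ln(r).

```python
import math

import jax
import jax.numpy as jnp


def u_of_k(k, r, rho):
    """u(k) = 4*pi * int r^2 rho(r) j0(k r) dr, trapezoid rule in ln(r)."""
    # j0(x) = sin(x)/x and jnp.sinc(t) = sin(pi t)/(pi t), so j0(x) = sinc(x/pi).
    j0 = jnp.sinc(k * r / jnp.pi)
    # dr = r d(ln r) contributes the extra power of r.
    integrand = 4.0 * jnp.pi * r**3 * rho * j0
    return jnp.sum(0.5 * (integrand[1:] + integrand[:-1]) * jnp.diff(jnp.log(r)))


def gaussian_u(s, k, r):
    """Transform of a Gaussian profile of width s (test case with a known answer)."""
    rho = jnp.exp(-0.5 * (r / s) ** 2)
    return u_of_k(k, r, rho)


r = jnp.logspace(-4.0, math.log10(20.0), 4000)
s, k = 1.0, 1.0

value = gaussian_u(s, k, r)
dvalue_ds = jax.grad(gaussian_u)(s, k, r)  # derivative with respect to s

# Closed-form results for the Gaussian profile.
exact = (2.0 * math.pi) ** 1.5 * s**3 * math.exp(-0.5 * (k * s) ** 2)
exact_ds = (2.0 * math.pi) ** 1.5 * (3.0 * s**2 - k**2 * s**4) * math.exp(-0.5 * (k * s) ** 2)
```

This kind of quadrature is only reliable while the grid resolves the oscillations of sin(k r), which is where FFTLog-based tools such as mcfit have the advantage, so a comparison at the largest k and smallest angles of interest would be the relevant stability check.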