DifferentiableUniverseInitiative / jax_cosmo

A differentiable cosmology library in JAX
MIT License

Error in `background.radial_comoving_distance` ? #128

Open MickaelRigault opened 6 days ago

MickaelRigault commented 6 days ago

Hello guys,

I'm sorry if this is a silly question, but I suspect there is an issue with the factor of h in `background.radial_comoving_distance`: the docstring says it returns distances in Mpc/h, whereas I think it returns them in Mpc * h.
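For reference, this is the standard line-of-sight comoving distance and where h enters it, writing $H_0 = 100\,h\ \mathrm{km\,s^{-1}\,Mpc^{-1}}$ and $E(z) = H(z)/H_0$:

$$
D_C(z) = \frac{c}{H_0}\int_0^z \frac{dz'}{E(z')}, \qquad \frac{c}{H_0} = \frac{299792.458\ \mathrm{km\,s^{-1}}}{100\,h\ \mathrm{km\,s^{-1}\,Mpc^{-1}}} = 2997.92458\,h^{-1}\,\mathrm{Mpc}.
$$

Under this convention, a distance quoted numerically "in Mpc/h" (i.e. in $h^{-1}\,\mathrm{Mpc}$) equals $h \times D_C/\mathrm{Mpc}$.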

Here is an example comparing against `astropy.cosmology`:

```python
import jax.numpy as jnp
import numpy as np
import jax_cosmo as jcosmo
from astropy.cosmology import Planck15 as astropy_planck15

jc_planck15 = jcosmo.Planck15()

# jax-cosmo
z = jnp.asarray([0.5, 2.5])
dist_mpch = jcosmo.background.radial_comoving_distance(jc_planck15, a=jcosmo.utils.z2a(z))
# If the output is in Mpc/h, as documented:
dist_mpc_expected = dist_mpch * jc_planck15.h  # Mpc/h * h => Mpc
dist_mpc_divided = dist_mpch / jc_planck15.h   # Mpc/h / h => Mpc/h^2

# astropy
dist_mpc_astropy = astropy_planck15.comoving_distance(np.asarray(z)).value  # returns Mpc

print(f"astropy: {dist_mpc_astropy}")
print(f"expected dist_mpch * jc_planck15.h: {dist_mpc_expected}")
print(f"dist_mpch / jc_planck15.h: {dist_mpc_divided}")
```
```
astropy: [1945.56126208 5971.73020615]
expected dist_mpch * jc_planck15.h: [ 893.26105 2744.334  ]
dist_mpch / jc_planck15.h: [1946.6506 5980.625 ]
```
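The printed numbers can be checked without re-running jax-cosmo. A minimal sketch using only the three printed rows: since `(dist_mpch * h) / (dist_mpch / h) = h**2`, h can be read back from the output itself, and each candidate conversion can then be compared with astropy's distance in Mpc. The only assumption is that both libraries use the same h; the remaining ~0.1% gap is not explained here (a different radiation/neutrino treatment is one unverified guess).

```python
import numpy as np

# Numbers copied verbatim from the printed output (z = 0.5 and z = 2.5).
astropy_mpc = np.array([1945.56126208, 5971.73020615])  # astropy, documented in Mpc
times_h = np.array([893.26105, 2744.334])               # dist_mpch * jc_planck15.h
over_h = np.array([1946.6506, 5980.625])                # dist_mpch / jc_planck15.h

# times_h / over_h == h**2, so h is recoverable from the output alone.
h = np.sqrt(times_h / over_h)

# Raw radial_comoving_distance output, reconstructed from the "/ h" row.
raw = over_h * h

# Ratio of each candidate conversion to the astropy distance in Mpc.
ratio_over_h = over_h / astropy_mpc
ratio_times_h = times_h / astropy_mpc

print("h recovered:", h)
print("raw output:", raw)
print("(dist_mpch / h) / astropy:", ratio_over_h)
print("(dist_mpch * h) / astropy:", ratio_times_h)
```

With these numbers the "/ h" row agrees with astropy to within about 0.15%, while the "* h" row is smaller by roughly a factor of h².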
MickaelRigault commented 6 days ago

See also:

```python
import jax.numpy as jnp
import numpy as np
import jax_cosmo as jcosmo
import matplotlib.pyplot as plt

from astropy.cosmology import Planck15 as astropy_planck15
jc_planck15 = jcosmo.Planck15()

z = jnp.linspace(0.001, 1.5, 100).astype("float32")

fig, ax = plt.subplots()

dist_mpch = jcosmo.background.radial_comoving_distance(jc_planck15, a=jcosmo.utils.z2a(z))  # in Mpc/h (??) or Mpc*h (!!)

ax.plot(z, dist_mpch / jc_planck15.h, label="dist_mpch / jc_planck15.h", ls=":", lw=5)
ax.plot(z, dist_mpch * jc_planck15.h, label="dist_mpch * jc_planck15.h", ls="--")
ax.plot(z, astropy_planck15.comoving_distance(np.asarray(z)).value, label="astropy comoving_distance", ls="-")
ax.legend()

ax.set_ylabel("dist [Mpc]")
ax.set_xlabel("redshift")
plt.show()
```

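[Figure: comoving distance [Mpc] versus redshift, comparing `dist_mpch / jc_planck15.h`, `dist_mpch * jc_planck15.h` and astropy `comoving_distance`]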
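To see where h enters, here is a standalone numerical sketch of $D_C = (c/H_0)\int_0^z dz'/E(z')$. It is not jax-cosmo's implementation. It assumes a flat ΛCDM model with no radiation and Planck15-like values picked for illustration (Ωm = 0.3075, h = 0.6774). Using the prefactor c/(100 km/s/Mpc) = 2997.92458 gives the distance in h⁻¹ Mpc; dividing by h converts it to Mpc. The helper name `comoving_distance_mpc_over_h` is hypothetical.

```python
import numpy as np

C_KM_S = 299792.458                          # speed of light in km/s
HUBBLE_DIST_MPC_OVER_H = C_KM_S / 100.0      # c / (100 km/s/Mpc) = 2997.92458 Mpc/h


def comoving_distance_mpc_over_h(z, omega_m, n_steps=4096):
    """Hypothetical helper: flat LCDM comoving distance in Mpc/h (no radiation assumed)."""
    zs = np.linspace(0.0, z, n_steps + 1)
    inv_e = 1.0 / np.sqrt(omega_m * (1.0 + zs) ** 3 + (1.0 - omega_m))  # 1 / E(z')
    # Trapezoidal rule for the integral of 1/E(z') from 0 to z.
    integral = np.sum(0.5 * (inv_e[1:] + inv_e[:-1]) * np.diff(zs))
    return HUBBLE_DIST_MPC_OVER_H * integral


h = 0.6774        # assumed Planck15-like value
omega_m = 0.3075  # assumed Planck15-like value

d_mpc_over_h = comoving_distance_mpc_over_h(0.5, omega_m)
d_mpc = d_mpc_over_h / h  # Mpc/h -> Mpc: divide by h

print(f"D_C(z=0.5) = {d_mpc_over_h:.1f} Mpc/h = {d_mpc:.1f} Mpc")
```

Keeping h out of the integral, and applying it only in the final conversion, is what lets a library quote distances in Mpc/h independently of h.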