eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

Fix powspec round-off error problem and add jit #23

Closed: eelregit closed this 11 months ago

eelregit commented 11 months ago

When the density field is large, single precision can cause round-off errors in JAX's bincount, because the weighted sums are accumulated in float32. I added a version that upcasts during the scatter, and tried to add it to JAX in https://github.com/google/jax/pull/18393, so that we can use that instead if it gets merged and released.
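A minimal sketch of the upcast-during-scatter idea, assuming 64-bit mode is enabled (otherwise JAX silently turns float64 into float32). The function `bincount_acc` and its `acc_dtype` argument are hypothetical names for illustration, not the code in this PR or in the JAX PR: the histogram is allocated in a wider dtype and the weights are scatter-added into it with `.at[].add`.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Assumption: 64-bit types are enabled so that float64 arrays really are float64.
jax.config.update("jax_enable_x64", True)


@partial(jax.jit, static_argnames=("length", "acc_dtype"))
def bincount_acc(x, weights, *, length, acc_dtype=jnp.float64):
    """Hypothetical weighted bincount that accumulates in `acc_dtype`.

    Assumes 0 <= x < length. `length` is static so the output shape is known under jit.
    """
    acc = jnp.zeros(length, dtype=acc_dtype)
    # Upcast the weights and scatter-add them into the wide accumulator.
    return acc.at[x].add(weights.astype(acc_dtype))


x = jnp.array([0] * 11 + [1], dtype=jnp.int32)
# One large value plus ten small ones in bin 0, all stored in float32.
w32 = jnp.array([1e8] + [1.0] * 10 + [2.0], dtype=jnp.float32)

out32 = jnp.bincount(x, weights=w32, length=2)  # accumulates in float32
out64 = bincount_acc(x, w32, length=2)          # accumulates in float64
print(out32, out64)
```

The exact total of bin 0 is 100000010, which float32 cannot represent (at that magnitude consecutive float32 values are 8 apart), so only the float64 accumulator recovers it.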

eelregit commented 11 months ago

Apparently JAX is reluctant to add parameters beyond those NumPy already has. Its bincount already has an extra length parameter, which is necessary for it to be jittable. It seems natural to me, though unfortunately not to the JAX maintainers, to add a dtype parameter as well, since JAX and NumPy treat type promotion differently.
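The type-promotion difference can be seen directly in a small sketch. NumPy's `np.bincount` converts the weights to float64, while `jnp.bincount` returns an array in the weights' own dtype. The keyword-only `length` argument is what makes `jnp.bincount` usable under `jax.jit`, because the output shape cannot depend on traced values; `hist4` is only an illustrative name.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = np.array([0, 1, 1, 3])
w = np.array([0.5, 1.0, 2.0, 4.0], dtype=np.float32)

# NumPy converts the weights to float64 before counting, so the result is float64.
out_np = np.bincount(x, weights=w)

# jnp.bincount keeps the weights' dtype, so float32 weights give a float32 result.
out_jnp = jnp.bincount(x, weights=w)


# Under jit, x is traced and its max is unknown, so the output length is given statically.
@jax.jit
def hist4(x, w):
    return jnp.bincount(x, weights=w, length=4)


out_jit = hist4(x, w)
print(out_np.dtype, out_jnp.dtype, out_jit)
```

Since there is no dtype argument, the only way to get a wider accumulator from jnp.bincount is to pass wider weights.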

eelregit commented 11 months ago

I didn't realize that XLA has no mixed-precision scatter-add. So upcasting the weights, as a temporary solution, has the same effect as the JAX PR above.
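The temporary workaround can be sketched as follows, again assuming 64-bit mode is enabled. Because jnp.bincount allocates its histogram in the weights' dtype, and the scatter-add uses that one dtype for both the accumulator and the updates, upcasting the weights upcasts the whole accumulation. The NumPy result serves as a float64 reference; the data are made up for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Assumption: 64-bit types are enabled, otherwise astype(jnp.float64) silently yields float32.
jax.config.update("jax_enable_x64", True)

x = jnp.array([2, 0, 2, 2, 1, 2], dtype=jnp.int32)
w32 = jnp.array([3e7, 1.0, 0.5, 0.25, 7.0, 0.25], dtype=jnp.float32)

# Same call, different weight dtype: the histogram dtype follows the weights.
out32 = jnp.bincount(x, weights=w32, length=3)
out64 = jnp.bincount(x, weights=w32.astype(jnp.float64), length=3)

# Reference: NumPy accumulates weighted counts in float64.
ref = np.bincount(np.asarray(x), weights=np.asarray(w32), minlength=3)
print(out32, out64, ref)
```

Bin 2 should total 30000001, which is odd and therefore not representable in float32 at that magnitude, so the float32 result necessarily differs while the upcast one matches the reference.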