eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

Fix powspec round-off error problem and add jit #23

Closed: eelregit closed this 1 year ago

eelregit commented 1 year ago

When the density field is large, single-precision accumulation in JAX bincount causes round-off errors in the binned sums. I added a version that upcasts during the scatter, and I am trying to add it to JAX in https://github.com/google/jax/pull/18393, so that we can use that instead if it gets merged and released.
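A minimal sketch of the idea, assuming 64-bit mode is enabled in JAX. The helper name bincount_upcast and its acc_dtype parameter are hypothetical and are not pmwd's actual code. The weights are cast to the accumulator dtype before the scatter-add, so the running sums stay in double precision while the inputs remain float32:

```python
from functools import partial

import jax

# Assumption: 64-bit mode must be on, otherwise JAX creates float32
# arrays even when float64 is requested.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames=("length", "acc_dtype"))
def bincount_upcast(x, weights, length, acc_dtype=jnp.float64):
    """Hypothetical helper: weighted bincount accumulated in acc_dtype.

    The weights are cast to the accumulator dtype before the scatter-add,
    so every partial sum is kept in higher precision.
    """
    acc = jnp.zeros(length, dtype=acc_dtype)
    return acc.at[x].add(weights.astype(acc_dtype))


n = 1_000_000
x = jnp.zeros(n, dtype=jnp.int32)        # every sample lands in bin 0
w = jnp.full(n, 0.1, dtype=jnp.float32)  # single-precision weights

lo = jnp.bincount(x, w, length=2)        # accumulates in float32
hi = bincount_upcast(x, w, length=2)     # accumulates in float64
ref = np.asarray(w, dtype=np.float64).sum()  # float64 reference sum

print(lo[0], hi[0], ref)
```

The float32 result lo[0] can drift from ref once a bin's total is large compared with the individual weights. The size of the drift depends on the accumulation order XLA uses on a given backend. The upcast result hi[0] agrees with ref to double precision.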

eelregit commented 1 year ago

Apparently JAX prefers not to add parameters beyond those NumPy already has. It already added a length parameter, which is necessary for bincount to be jittable. It seems natural to me, though unfortunately not to the JAX maintainers, to add a dtype parameter as well, because JAX and NumPy treat type promotion differently.
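A short sketch of both API points, using only the public jax.numpy.bincount and numpy.bincount calls. The wrapper name weighted_hist is illustrative. Under jit the maximum of x is a traced value, so the output size must be given through the static length keyword. For the dtype point, NumPy's weighted bincount returns float64 regardless of the weights' dtype, while JAX keeps the weights' dtype. This is why a float32 input ends up accumulated in float32:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames="length")
def weighted_hist(x, weights, length):
    # Under jit, x.max() is a tracer, so the output size cannot be inferred
    # from the data; the static length keyword fixes the output shape.
    return jnp.bincount(x, weights, length=length)


x = np.array([0, 1, 1, 3])
w = np.array([0.5, 1.0, 1.0, 2.0], dtype=np.float32)

out_jax = weighted_hist(x, w, length=4)
out_np = np.bincount(x, weights=w)

print(out_jax.dtype)  # float32: JAX keeps the weights' dtype
print(out_np.dtype)   # float64: NumPy accumulates weights in double precision
```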

eelregit commented 1 year ago

I didn't realize that XLA doesn't have a mixed-precision scatter-add. So upcasting the weights, as a temporary solution, has the same effect as the JAX PR above.
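A sketch of the temporary workaround, assuming 64-bit mode is enabled. The random data is made up for illustration. It casts the weights to float64 up front and then calls the stock jnp.bincount. The second computation is a float64 scatter that is fed float32 updates. XLA's scatter-add takes operand and updates of a single dtype, so JAX converts the updates to float64 before adding, and the two results should agree:

```python
import jax

# Assumption: 64-bit mode is required for float64 arrays in JAX.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
n_bins = 8
x = jnp.asarray(rng.integers(0, n_bins, size=100_000))
w = jnp.asarray(rng.random(100_000), dtype=jnp.float32)  # single-precision weights

# Workaround: upcast the weights, then call the stock jnp.bincount.
workaround = jnp.bincount(x, w.astype(jnp.float64), length=n_bins)

# "Upcasting scatter" written directly: the float32 updates are converted
# to the float64 operand dtype before the scatter-add is performed.
scatter = jnp.zeros(n_bins, dtype=jnp.float64).at[x].add(w)

print(workaround.dtype, bool(jnp.allclose(workaround, scatter)))
```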