eelregit closed this 1 year ago.
Apparently JAX doesn't want parameters beyond what NumPy already has. JAX already added a `length` parameter, which is necessary for `bincount` to be jittable. It seems natural to me, though unfortunately not to JAX, to have a `dtype` parameter as well, because JAX and NumPy treat type promotion differently.
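The type-promotion difference can be seen in a small comparison. This is a minimal sketch with made-up inputs: NumPy's `bincount` casts the weights to double, while `jnp.bincount` keeps the dtype of the weights, so float32 weights accumulate in float32.

```python
import numpy as np
import jax.numpy as jnp

x = np.array([0, 1, 1, 3])
w = np.array([0.5, 1.0, 2.0, 4.0], dtype=np.float32)

# NumPy casts the weights to double, so the result is float64.
np_out = np.bincount(x, weights=w)

# jnp.bincount keeps the dtype of the weights; `length` must be static under jit.
jnp_out = jnp.bincount(jnp.asarray(x), weights=jnp.asarray(w), length=4)

print(np_out.dtype, jnp_out.dtype)  # float64 float32
```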
I didn't realize that XLA doesn't have a mixed-precision scatter-add. So upcasting the weights, as a temporary solution, will have the same effect as the proposed JAX PR.
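A minimal sketch of the temporary workaround, assuming float64 is acceptable and that x64 mode can be enabled globally (`bincount_upcast` is a hypothetical helper name). The test input uses one large weight followed by small ones: in float32, `2**24 + 1` rounds back to `2**24`, so whether the small contributions survive depends on the accumulation order, while float64 sums them exactly.

```python
import jax

# float64 needs x64 mode; without it JAX would silently keep the weights in float32.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def bincount_upcast(x, weights, length):
    """Hypothetical workaround: cast the weights to float64 before calling
    jnp.bincount, so its internal scatter-add accumulates in double precision."""
    return jnp.bincount(x, weights.astype(jnp.float64), length=length)


# All three weights land in bin 0.
x = jnp.zeros(3, dtype=jnp.int32)
w = jnp.array([2.0**24, 1.0, 1.0], dtype=jnp.float32)

print(jnp.bincount(x, w, length=1))     # float32 accumulation; may lose the +1s
print(bincount_upcast(x, w, length=1))  # float64 accumulation
```

The cost is twice the memory for the weights and the output, which matters for large density fields.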
When the density field is large, single precision can cause accuracy problems in JAX `bincount`, because the weighted sums accumulate in float32. I added a version that upcasts during the scatter and am trying to add it to JAX in https://github.com/google/jax/pull/18393, so that we can use that instead if it gets merged and released.
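A version that upcasts during the scatter can be sketched as a jitted scatter-add with an explicit accumulation dtype. This is an illustration under stated assumptions, not the code from the PR: `bincount_acc` and `acc_dtype` are hypothetical names, x64 mode is assumed to be enabled, and since XLA has no mixed-precision scatter-add, the weights themselves are cast to `acc_dtype` before scattering.

```python
from functools import partial

import jax

jax.config.update("jax_enable_x64", True)  # needed for float64 accumulation

import jax.numpy as jnp


@partial(jax.jit, static_argnames=("length", "acc_dtype"))
def bincount_acc(x, weights, length, acc_dtype=jnp.float64):
    """Hypothetical bincount with an explicit accumulation dtype (not a JAX API).

    The output has static size `length` (required under jit) and dtype
    `acc_dtype`; the weights are cast to match, since XLA's scatter-add
    needs operands of the same dtype.
    """
    out = jnp.zeros(length, dtype=acc_dtype)
    # Mirror jnp.bincount: negative indices are clipped to bin 0.
    idx = jnp.maximum(x, 0)
    return out.at[idx].add(weights.astype(acc_dtype))


x = jnp.array([0, 2, 2, 5])
w = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
print(bincount_acc(x, w, length=6))  # float64 result
```

Making `acc_dtype` a static argument keeps the output dtype known at trace time, which is what a `dtype` parameter on `bincount` itself would provide.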