colehaus closed this 1 month ago
Hi - thanks for the report! This looks like a JAX issue, and would be better reported at http://github.com/google/jax/. There is nothing that can be done in this repository to affect the behavior of JAX's JIT compiler. Thanks!
Okay, thanks, I'll move it over there. I wasn't sure whether some peculiarity of the bfloat16 implementation was affecting which optimizations the JIT compiler considered safe (since the errors don't occur for, e.g., float16).
I have some stochastic rounding code and uncovered a bug when trying to use it as follows:
With bfloat16, the final line prints `True`, even though the preceding line makes it clear that not all errors should be 0. `np.float32` does not show this behavior.

Here are some lowering and compilation outputs, in case they're helpful. First bfloat16, then float32:
Here's the info from `jax.print_environment_info()`:

(Let me know if this is a better fit for the main JAX repo.)
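For anyone trying to reproduce this, here is a minimal sketch of float32 to bfloat16 stochastic rounding in JAX. It is not the reporter's code, which isn't included in this thread. It assumes the common bit-manipulation approach: add uniform random noise to the 16 low-order bits of the float32 bit pattern, then truncate them, since bfloat16 keeps only the upper 16 bits. It also shows one way to print the kind of lowering and compilation output mentioned in the report, via `jax.jit(...).lower(...).as_text()` and `.compile().as_text()`. The helper name `stochastic_round_bf16` and the seed are hypothetical.

```python
import jax
import jax.numpy as jnp


def stochastic_round_bf16(key: jax.Array, x: jax.Array) -> jax.Array:
    """Hypothetical helper: stochastically round float32 values to bfloat16.

    bfloat16 keeps the upper 16 bits of a float32. Adding uniform noise in
    [0, 2**16) to the lower 16 bits before truncating rounds the magnitude up
    with probability equal to the discarded fraction of one bfloat16 ulp.
    """
    bits = jax.lax.bitcast_convert_type(x.astype(jnp.float32), jnp.uint32)
    # One random 16-bit value per element (top 16 bits of a random uint32).
    noise = jax.random.bits(key, x.shape, jnp.uint32) >> 16
    # Add the noise, then zero out the 16 low-order bits.
    rounded = (bits + noise) & jnp.uint32(0xFFFF0000)
    # The result is exactly representable in bfloat16, so this cast is lossless.
    return jax.lax.bitcast_convert_type(rounded, jnp.float32).astype(jnp.bfloat16)


key_data, key_round = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_data, (1000,), dtype=jnp.float32)

# Eager (un-jitted) run: each op is dispatched separately.
y = stochastic_round_bf16(key_round, x)
errors = x - y.astype(jnp.float32)
print(bool(jnp.all(errors == 0)))  # False for generic float32 inputs

# Lowering (StableHLO) and compiled (optimized HLO) text for the jitted version.
lowered = jax.jit(stochastic_round_bf16).lower(key_round, x)
print(lowered.as_text())
print(lowered.compile().as_text())
```

Comparing the compiled text for a bfloat16 pipeline against a float32 one is a reasonable way to check whether a conversion or subtraction has been simplified away.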