jit doesn't guarantee that the intermediates and the outputs will be exactly equal to the ones without jit. I think you might need to use jnp.abs(sig2) <= eps to be robust to that.
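A minimal sketch of that suggestion follows. It assumes a hypothetical safe_inverse in place of the issue's get_neff, whose code is not shown here, and a hypothetical tolerance EPS. The branch is chosen with jnp.abs(sig2) <= EPS instead of an exact comparison, so a tiny rounding residue of either sign still takes the "zero" branch, with or without jit.

```python
import jax
import jax.numpy as jnp
from jax import lax

EPS = 1e-6  # hypothetical tolerance; scale it to the typical magnitude of sig2


def safe_inverse(sig2):
    # Hypothetical stand-in for get_neff: return 0 when sig2 is numerically zero,
    # otherwise 1 / sig2. Signature: lax.cond(pred, true_fun, false_fun, operand).
    return lax.cond(
        jnp.abs(sig2) <= EPS,
        lambda s: jnp.zeros_like(s),  # treat |sig2| <= EPS as exactly zero
        lambda s: 1.0 / s,            # only reached when |s| > EPS
        sig2,
    )


safe_inverse_jit = jax.jit(safe_inverse)

# Eager and jitted calls pick the same branch even for tiny nonzero residues.
for s in (0.0, 1e-9, -1e-9, 0.5):
    x = jnp.float32(s)
    print(s, float(safe_inverse(x)), float(safe_inverse_jit(x)))
```

The tolerance is a design choice: it should be well above the rounding noise of sig2 but below any value that is genuinely meaningful for the computation.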
Thanks! This is helpful
Description
Hi everyone,
I noticed some very strange behaviour from a function when it is jitted.
Here's the code:
The non-jitted function prints the correct result (0.), while the jitted function seems to take the second (false) branch of lax.cond. The problem seems to be solved if I change the boolean condition of lax.cond to sig2 <= 0.0, but I'm not sure this condition holds in general. Can someone explain why this occurs and how to solve it? Note that I need to use the function get_neff within another jitted function.
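To see why sig2 <= 0.0 is not a general fix, the sketch below evaluates three candidate predicates on small residues of either sign. These residues stand in for a sig2 that is mathematically zero but picks up rounding error. The residue values and EPS are hypothetical. The comparison sig2 <= 0.0 only catches the negative residue, while the tolerance check catches both. The sketch also calls one jitted function from another, which is fine: the inner jit is simply traced into the outer one.

```python
import jax
import jax.numpy as jnp

EPS = 1e-6  # hypothetical tolerance


@jax.jit
def predicates(sig2):
    # Three ways to test "sig2 is zero" for a scalar sig2.
    exact = sig2 == 0.0
    nonpositive = sig2 <= 0.0
    tolerant = jnp.abs(sig2) <= EPS
    return exact, nonpositive, tolerant


@jax.jit
def outer(values):
    # A jitted function calling another jitted function, mapped over a batch.
    return jax.vmap(predicates)(values)


residues = jnp.array([0.0, -1e-9, 1e-9], dtype=jnp.float32)
exact, nonpositive, tolerant = outer(residues)
print(exact.tolist())        # [True, False, False]
print(nonpositive.tolist())  # [True, True, False]
print(tolerant.tolist())     # [True, True, True]
```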
System info (python version, jaxlib version, accelerator, etc.)
System info: