dionhaefner opened this issue 3 months ago
@dionhaefner I'm not sure what exactly the problem with the jitted function is here, ~but I think we can rewrite the `calc_topo_kernel` function more clearly, and then the check will also pass:~
That's because your rewritten function isn't the same as the original, and you're executing it twice on the same device.
The issue is that this function gives wrong results on GPU but not on CPU.
Yes, sorry, I got the ops in the function wrong. The non-jitted version shows no difference between GPU and CPU results, but I can confirm that the jitted version has some discrepancy.
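To narrow down which combination diverges, one could compare the eager and jitted outputs of the same function on each backend. The sketch below is an assumption-laden diagnostic, not part of Veros: the helper names `jit_discrepancy` and `report` are hypothetical, and `fn` would be `calc_topo_kernel` called with its real inputs. `jax.disable_jit()` is used so that the "eager" run stays eager even if `fn` is itself decorated with `jax.jit`.

```python
import jax
import numpy as np


def jit_discrepancy(fn, *args, device=None):
    """Max absolute difference between eager and jitted fn(*args) on one device.

    Hypothetical helper: inputs are committed to `device` (default: first
    device of the default backend) so both variants execute there.
    """
    device = device if device is not None else jax.devices()[0]
    committed = [jax.device_put(a, device) for a in args]
    with jax.disable_jit():
        # Run op by op, even if fn (or something it calls) is jit-decorated.
        eager = np.asarray(fn(*committed))
    jitted = np.asarray(jax.jit(fn)(*committed))
    return float(np.max(np.abs(eager - jitted)))


def report(fn, *args, backends=("cpu", "gpu")):
    """Print the eager-vs-jit discrepancy for each available backend."""
    for backend in backends:
        try:
            devices = jax.devices(backend)
        except RuntimeError:
            continue  # backend not present on this machine
        d = jit_discrepancy(fn, *args, device=devices[0])
        print(f"{backend}: max |eager - jit| = {d:.3e}")
```

A nonzero value on the GPU line but not on the CPU line would match the behaviour described above.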
Description
I noticed in Veros that we would sometimes observe solver divergences on GPU, but not on CPU. I've isolated the problem to a kernel that gives different results depending on the backend, and this reproduces across several different machines.
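A backend comparison only means something if the same function and the same inputs are run on two different devices. A minimal sketch of such a check follows, assuming a CPU backend is always present; `run_jitted_on` and `compare_backends` are hypothetical helpers, and the tolerances are illustrative defaults, not values taken from Veros.

```python
import jax
import numpy as np


def run_jitted_on(backend, fn, *args):
    """Run jax.jit(fn) with all inputs committed to the first device of `backend`."""
    device = jax.devices(backend)[0]
    committed = [jax.device_put(a, device) for a in args]
    return np.asarray(jax.jit(fn)(*committed))


def compare_backends(fn, *args, reference="cpu", other="gpu", rtol=1e-5, atol=1e-8):
    """Compare jitted results across two backends.

    Returns None when `other` is unavailable, otherwise a tuple
    (results_match, max_abs_difference).
    """
    try:
        jax.devices(other)
    except RuntimeError:
        return None  # e.g. no GPU on this machine
    ref = run_jitted_on(reference, fn, *args)
    out = run_jitted_on(other, fn, *args)
    match = bool(np.allclose(out, ref, rtol=rtol, atol=atol))
    return match, float(np.max(np.abs(out - ref)))
```

With the inputs from the pickle file, a call might look like `compare_backends(calc_topo_kernel, *inputs)`; how the pickled inputs unpack into arguments is an assumption here.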
Reproducer:
This prints:
Here's the pickle file containing the inputs to the function:
jaxbug.zip
System info (python version, jaxlib version, accelerator, etc.)