Open shawnwwimer opened 1 day ago
There is a significant accuracy loss in diffraction calculations when using JAX with float32 (which is the default when importing JAX).
Can you check if enabling JAX x64 solves your problem?
You need to add the following lines before importing diffractsim:
import jax
jax.config.update("jax_enable_x64", True)
Yes, that does fix it. The difference is definitely the precision. I brought it up simply because I didn't see a related issue here, and I was surprised to find out it was due to default JAX behavior.
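For anyone wondering why float32 matters so much here, a rough sketch (not taken from diffractsim itself): propagation phases of the form 2πz/λ get very large at optical wavelengths, and float32 only carries about 7 significant digits. The wavelength and distance below are illustrative values I picked, not ones from the original simulation.

```python
import jax

# x64 must be enabled before any arrays are created,
# otherwise float64 requests are silently downcast to float32.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

wavelength = 500e-9  # illustrative: 500 nm light
z = 1.0              # illustrative: 1 m propagation distance

# Phase accumulated over the distance z, in radians (about 1.26e7).
phase64 = jnp.asarray(2 * jnp.pi * z / wavelength, dtype=jnp.float64)

# The same phase rounded to float32, as it would be with JAX defaults.
phase32 = phase64.astype(jnp.float32)

# Absolute rounding error in radians.
err = abs(float(phase32) - float(phase64))
print(phase64.dtype, phase32.dtype, err)
```

At this magnitude adjacent float32 values are a whole radian apart, so the rounded phase can be off by a sizeable fraction of a cycle, which is enough to disagree visibly with an analytical result.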
I was using JAX for the backend and couldn't get some simulations to agree with an analytical form. I noticed that changing the backend to the CPU fixed this problem, and found that JAX uses low precision by default for some operations:
I don't have the time to test right now, but from that second issue: "Try setting jax.default_matmul_precision to float32". If anybody runs into a similar problem this may be the cause. If so, it may be good to note it in the readme.
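A minimal sketch of what that suggestion could look like, using the `jax.default_matmul_precision` context manager. This is untested against diffractsim, and I am assuming the setting mainly matters on GPU/TPU backends, where matrix multiplications may otherwise run at reduced internal precision; the arrays are placeholders.

```python
import jax
import jax.numpy as jnp

# Placeholder float32 inputs, standing in for real simulation arrays.
a = jnp.ones((4, 4), dtype=jnp.float32)

# Request full float32 precision for matmuls inside this block.
with jax.default_matmul_precision("float32"):
    c = a @ a

print(c.dtype, c[0, 0])
```

The same setting can be applied globally with `jax.config.update("jax_default_matmul_precision", "float32")`. Note that this only controls matmul-like operations and does not change array dtypes, so it is not a substitute for `jax_enable_x64`.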