rafael-fuente / diffractsim

✨🔬 A flexible diffraction simulator for exploring and visualizing physical optics.
https://rafael-fuente.github.io/simulating-diffraction-patterns-with-the-angular-spectrum-method-and-python.html

Numerical error with JAX #60

Open shawnwwimer opened 1 day ago

shawnwwimer commented 1 day ago

I was using the JAX backend and couldn't get some simulations to agree with an analytical solution. Switching the backend to the CPU fixed the problem, and I found that JAX uses low precision by default for some operations, as discussed in the JAX issue tracker.

I don't have time to test it right now, but the second of those issues suggests: "Try setting jax.default_matmul_precision to float32". If anybody runs into a similar problem, this may be the cause; if so, it would be worth noting in the README.
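
The setting quoted above corresponds to JAX's jax_default_matmul_precision config option, which can be set globally or only around a block of code. A minimal sketch of both follows; it has not been tried with diffractsim, and the arrays are placeholders. As far as I know, this option only changes the precision of matrix multiplications and convolutions on accelerators (for example TF32 on recent NVIDIA GPUs or bfloat16 passes on TPUs), so it may not affect FFT-based propagation at all.

```python
import jax
import jax.numpy as jnp

# Option 1: change the default for the whole program (set once, at startup).
jax.config.update("jax_default_matmul_precision", "float32")

# Option 2: change it only for the operations inside a `with` block.
a = jnp.linspace(0.0, 1.0, 64 * 64, dtype=jnp.float32).reshape(64, 64)
b = 3.0 * jnp.eye(64, dtype=jnp.float32)  # placeholder matrix
with jax.default_matmul_precision("float32"):
    c = a @ b  # matrix product computed with full float32 precision
```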

rafael-fuente commented 1 day ago

There is a significant accuracy loss in diffraction calculations when JAX runs in float32, which is its default precision after importing it.
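
One plausible reason the loss is so large (a sketch, not a description of diffractsim's internals): the angular spectrum transfer function H(fx) = exp(i*2*pi*z*sqrt(1/lambda^2 - fx^2)) has a phase of about 1.3e7 rad for 500 nm light over 1 m. At that magnitude, adjacent float32 values are 1 rad apart, so the phase can be off by a sizable fraction of a radian even though its relative error is tiny. The NumPy snippet below evaluates H once in float32 and once in float64 and reports the largest difference; the wavelength, distance, and frequency grid are arbitrary assumptions.

```python
import numpy as np

def transfer_function(fx, wavelength, z, dtype):
    """Angular spectrum transfer function exp(i*2*pi*z*sqrt(1/wavelength**2 - fx**2)),
    with every intermediate value computed in `dtype` (np.float32 or np.float64)."""
    fx = fx.astype(dtype)
    inv_wavelength = dtype(1.0 / wavelength)
    # Longitudinal spatial frequency in cycles per meter (propagating waves only here).
    fz = np.sqrt(inv_wavelength**2 - fx**2)
    phase = dtype(2.0 * np.pi) * fz * dtype(z)  # ~1e7 rad for visible light over 1 m
    return np.exp(1j * phase)

wavelength = 500e-9  # assumed: 500 nm
z = 1.0              # assumed: 1 m propagation distance
fx = np.linspace(-0.5, 0.5, 101) / wavelength  # assumed grid, |fx| <= 0.5 / wavelength

H32 = transfer_function(fx, wavelength, z, np.float32)
H64 = transfer_function(fx, wavelength, z, np.float64)

max_phasor_error = np.max(np.abs(H32 - H64))
print(f"max |H32 - H64| = {max_phasor_error:.3f}")  # far above float32's ~1e-7 relative precision
```

In float64 the spacing between adjacent values near 1.3e7 is about 2e-9 rad, which is why switching to 64-bit precision removes this error.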

Can you check whether enabling JAX's 64-bit mode (x64) solves your problem?

You need to add the following lines before importing diffractsim:

import jax
jax.config.update("jax_enable_x64", True)
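
A quick way to confirm that the flag took effect before diffractsim creates any arrays is to check the default dtypes JAX produces afterwards. The dtype check below is a suggested addition, not something from this thread; with x64 enabled, JAX defaults to float64 and complex128, so FFTs also run in double precision.

```python
import jax
import jax.numpy as jnp

# Must run at startup, before any JAX arrays are created (and before importing diffractsim).
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 8)   # uses JAX's default float dtype
spectrum = jnp.fft.fft(x)       # complex result of an FFT
print(x.dtype, spectrum.dtype)  # float64 complex128 when x64 is enabled
```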

shawnwwimer commented 20 hours ago

Yes, that fixes it; the difference is definitely the precision. I raised it here because I didn't see a related issue in this repository, and I was surprised to learn that the cause was JAX's default behavior.