Closed: twhentschel closed this issue 10 months ago.
Thanks for the catch!
I pretty much always use 64-bit, but I know a lot of people don't, so I've relaxed the tolerances in the example and made the default epsilon values depend on the dtype. It should now work well with both 32-bit and 64-bit precision.
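A minimal sketch of what dtype-dependent default tolerances can look like, assuming the defaults are scaled from the dtype's machine epsilon. The helper `default_tolerance` and the factor of 100 are hypothetical illustrations, not the package's actual API.

```python
import jax.numpy as jnp

def default_tolerance(dtype, factor=100.0):
    """Hypothetical helper: a default tolerance scaled from the dtype's
    machine epsilon, so float32 inputs get a looser default than float64."""
    return factor * float(jnp.finfo(dtype).eps)

tol32 = default_tolerance(jnp.float32)  # about 1.2e-5
tol64 = default_tolerance(jnp.float64)  # about 2.2e-14
```

In practice the dtype would be taken from the inputs (for example `default_tolerance(x.dtype)`), so the same call gives a sensible default in both 32-bit and 64-bit mode.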
Hi @f0uriest,
Thanks for creating this neat package! The assert statement in the README example fails for me. The fix was to enable 64-bit precision in JAX.
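The original snippet is not preserved here. As a minimal sketch, one standard way to turn on 64-bit mode is JAX's `jax_enable_x64` configuration flag (setting the environment variable `JAX_ENABLE_X64=1` also works). It should be set at startup, before any arrays are created.

```python
import jax

# Enable 64-bit floats; do this at the top of the script, before creating arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64
```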
An alternative would be to remove the first assert statement and note that JAX uses 32-bit floats by default. The second check works fine, but checking precision to 1e-14 might be too strict for other examples.
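To illustrate why a 1e-14 check cannot pass in 32-bit mode, here is a small sketch (the value 0.1 is an arbitrary choice): simply storing 0.1 as a float32 already introduces an error far larger than 1e-14, although that error stays below float32 machine epsilon.

```python
import jax.numpy as jnp

# Default float dtype: float32 unless 64-bit mode is enabled.
print(jnp.asarray(0.1).dtype)

# Representation error of 0.1 when stored as float32.
x32 = jnp.asarray(0.1, dtype=jnp.float32)
err = abs(float(x32) - 0.1)

eps32 = float(jnp.finfo(jnp.float32).eps)
print(err > 1e-14)  # True: a fixed 1e-14 tolerance fails in float32
print(err < eps32)  # True: the error is within float32 resolution
```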
If you think this is a bug, I'd be happy to contribute a PR for either of these fixes.