null-a opened this issue 5 years ago
Using double precision will cause performance issues with #14, but we haven't found it to significantly impact performance on the CPU. Also, I think NumPyro (JAX) will cast everything to single precision unless an environment flag is set to let XLA operate in double precision. Have you found numerical instability issues with single precision? We found a few such cases in Pyro, but on closer inspection they turned out to be caused by numerical instability in our distribution implementations. Fixing those allowed us to get good results even with single precision. For NumPyro, most instances of numerical imprecision arise from XLA's fast-math mode, which can be turned off using `os.environ["XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false"`.
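A minimal sketch of the default JAX behaviour described above, assuming a recent JAX release where `jax.config.update("jax_enable_x64", True)` is the switch for 64-bit mode (the environment-variable form is `JAX_ENABLE_X64=True`). The input array is a hypothetical stand-in for a design matrix. The fast-math flag from the comment would be set the same way, through `os.environ` before `jax` is imported; it is left out of this sketch because newer XLA releases may reject flags they no longer recognise.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical double-precision input, e.g. a design matrix built with NumPy's default float64.
x = np.linspace(0.0, 1.0, 5)

# With JAX's default settings (x64 mode off), float64 inputs are converted to float32.
single = jnp.asarray(x)
print(single.dtype)  # float32

# Opt in to 64-bit arrays; arrays created from now on keep double precision.
jax.config.update("jax_enable_x64", True)
double = jnp.asarray(x)
print(double.dtype)  # float64
```

Note that enabling x64 only affects arrays created afterwards: `single` stays float32.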
> Have you found numerical instability issues with single precision?
Not that I'm aware of. I agree that using single precision everywhere seems fine. I'm glad to have this clear in my mind, thanks.
That said, as you hint, this is unlikely to make a noticeable difference (e.g. to performance on the CPU), so switching design matrices to single precision is probably a low priority for now.
@neerajprad: Am I right that you think it might be worth adding an option to use double precision?
Currently we use a mixture of both. For example, design matrices are initially encoded as doubles. This carries over to the data used for inference in NumPyro, but the data is converted to single precision when using Pyro. On the other hand, both back ends represent samples in single precision.
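A small NumPy sketch of the mixed-precision setup described above, assuming a hypothetical two-column design matrix (intercept plus one predictor). It shows the matrix starting out as float64 and then being downcast to float32, as happens on the Pyro path; with PyTorch the same cast could be written `torch.as_tensor(X, dtype=torch.float32)`. This is an illustration only, not the project's actual conversion code.

```python
import numpy as np

# Hypothetical design matrix: intercept column plus one predictor.
# NumPy builds it as float64 by default, i.e. "coded as doubles".
X = np.column_stack([np.ones(4), np.array([0.1, 0.2, 0.3, 0.4])])
print(X.dtype)  # float64

# Downcast to single precision before handing the data to a float32 back end.
X32 = X.astype(np.float32)
print(X32.dtype)  # float32

# Round-to-nearest keeps each entry's relative error within float32 machine epsilon.
rel_err = np.max(np.abs(X32.astype(np.float64) - X) / np.abs(X))
print(rel_err < np.finfo(np.float32).eps)  # True
```

The relative error check shows why the downcast on its own is rarely the source of numerical trouble: each entry moves by at most about one part in 10^7.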