Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

Setting config.update("jax_enable_x64", True) produces crash #154

Closed: amunozj closed this issue 5 months ago

amunozj commented 5 months ago

Describe the bug: I'm using NumPyro's JaxNS wrapper. In the forum I asked about running JaxNS in 64-bit (see #153). In that discussion the recommendation was to set the 64-bit flags. I'm currently setting:

from jax import config
config.update("jax_enable_x64", True)
import numpyro
numpyro.enable_x64()

Expected behavior: I expected the sampler to run in 64-bit precision.

Observed behavior: the sampler crashes because of a type mismatch:

TypeError: body_fun output and input must have identical types, got
CarryType(state=StaticStandardNestedSamplerState(key='ShapedArray(uint32[2])', next_sample_idx='ShapedArray(int32[])', sample_collection=StaticStandardSampleCollection(sender_node_idx='ShapedArray(int32[10000])', log_L='ShapedArray(float32[10000])', U_samples='ShapedArray(float32[10000,2])', num_likelihood_evaluations='ShapedArray(int32[10000])', phantom='ShapedArray(bool[10000])'), front_idx='ShapedArray(int32[50])'), termination_register=TerminationRegister(num_samples_used='ShapedArray(int32[])', evidence_calc=EvidenceCalculation(log_L='ShapedArray(float32[])', log_X_mean='ShapedArray(float32[])', log_X2_mean='ShapedArray(float32[])', log_Z_mean='ShapedArray(float32[])', log_ZX_mean='ShapedArray(float32[])', log_Z2_mean='ShapedArray(float32[])', log_dZ_mean='ShapedArray(float32[])', log_dZ2_mean='ShapedArray(float32[])'), evidence_calc_with_remaining=EvidenceCalculation(log_L='ShapedArray(float32[])', log_X_mean='ShapedArray(float32[])', log_X2_mean='ShapedArray(float32[])', log_Z_mean='ShapedArray(float32[])', log_ZX_mean='ShapedArray(float32[])', log_Z2_mean='ShapedArray(float32[])', log_dZ_mean='ShapedArray(float32[])', log_dZ2_mean='ShapedArray(float32[])'), num_likelihood_evaluations='DIFFERENT ShapedArray(int64[]) vs. ShapedArray(int32[])', log_L_contour='ShapedArray(float32[])', efficiency='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', plateau='ShapedArray(bool[])')).

The problem seems to be with these two variables:

...
TerminationRegister(
...
num_likelihood_evaluations='DIFFERENT ShapedArray(int64[]) vs. ShapedArray(int32[])', 
efficiency='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'
...
)

Removing the 64-bit flags avoids the crash. Any suggestions on how to fix this?

JAXNS version: jaxns==2.4.10
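A minimal sketch (not JAXNS code) of how this kind of error arises: `jax.lax.while_loop` requires the carry returned by the body to have exactly the same dtypes as the initial carry. The hypothetical carry `(count, efficiency)` below only mimics `num_likelihood_evaluations` and `efficiency`; it is initialized with explicit 32-bit dtypes while x64 is enabled, and the body mixes in strongly typed 64-bit constants, so the outputs are promoted to int64/float64 and tracing fails.

```python
import jax

# Enable 64-bit mode first, as in the report above.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import lax


def body(carry):
    count, efficiency = carry
    # Explicit 64-bit constants promote the 32-bit carry values to int64/float64.
    return count + jnp.int64(1), efficiency * jnp.float64(0.5)


def cond(carry):
    count, _ = carry
    return count < 10


# Hypothetical initial state built with explicit (strongly typed) 32-bit dtypes.
init = (jnp.zeros((), dtype=jnp.int32), jnp.ones((), dtype=jnp.float32))

try:
    lax.while_loop(cond, body, init)
    error = None
except TypeError as e:
    # The carry input/output dtypes differ, so while_loop refuses to trace the body.
    error = e
```

The exact wording of the `TypeError` message differs between JAX versions, but it reports the same input-versus-output carry type mismatch.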

Joshuaalbert commented 5 months ago

In general, this sometimes happens when a rogue operation creates a float32 while float64 is enabled. This can happen, for instance, if the following is not at the start of the Python program:

from jax import config
config.update("jax_enable_x64", True)

Can you confirm the location of the above?

If JAXNS code is loaded before this, then internal dtypes can end up as 32-bit, and enabling 64-bit afterwards allows constant conversion to create mismatches. In general, this can probably be resolved by carefully controlling the input/output dtype invariants, but that will take some work and testing.
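One way to picture the "input/output dtype invariant" idea is to cast every carry leaf back to its input dtype at the end of the loop body. The wrapper `dtype_preserving` below is hypothetical, a sketch under that assumption rather than how JAXNS handles it; the toy `(count, efficiency)` carry is also made up for illustration.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import lax


def dtype_preserving(body_fun):
    """Hypothetical wrapper: cast each output carry leaf back to its input dtype."""

    def wrapped(carry):
        new_carry = body_fun(carry)
        # Walk both pytrees together and force each new leaf to the old leaf's dtype.
        return jax.tree_util.tree_map(
            lambda new, old: new.astype(old.dtype), new_carry, carry
        )

    return wrapped


def body(carry):
    count, efficiency = carry
    # Without the wrapper, these 64-bit constants would promote the carry.
    return count + jnp.int64(1), efficiency * jnp.float64(0.5)


def cond(carry):
    count, _ = carry
    return count < 10


init = (jnp.zeros((), dtype=jnp.int32), jnp.ones((), dtype=jnp.float32))
count, efficiency = lax.while_loop(cond, dtype_preserving(body), init)
```

Note the trade-off: casting back silently truncates any 64-bit results to 32-bit, so it hides the mismatch rather than giving true 64-bit precision. Building the initial state with dtypes that match the active x64 setting, which is what enabling x64 early achieves, avoids that loss.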

amunozj commented 5 months ago

Thank you so much, @Joshuaalbert! Making sure the x64 config is set as early as possible does solve the issue!
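A minimal sketch of the ordering that resolved the issue, assuming a single entry-point script: the x64 flag is set before any other module has a chance to create JAX arrays. The `numpyro` and `jaxns` imports are only indicated in a comment so the sketch runs with JAX alone; the dtype checks are a quick way to confirm 64-bit mode is actually active.

```python
# Entry point: enable 64-bit mode before anything else creates JAX arrays.
import jax

jax.config.update("jax_enable_x64", True)

# Only after this, import libraries that may build arrays at import or
# model-construction time (e.g. `import numpyro`, `import jaxns`).
import jax.numpy as jnp

# Sanity check: default float and integer dtypes are now 64-bit.
default_float = jnp.zeros(()).dtype
default_int = jnp.zeros((), dtype=int).dtype
```

If these dtypes come out as float32/int32, some code created arrays or set dtypes before the flag was applied.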