JohannesBuchner / UltraNest

Fit and compare complex models reliably and rapidly. Advanced nested sampling.
https://johannesbuchner.github.io/UltraNest/

Float32 incompatibility #135

Closed LucaMantani closed 3 months ago

LucaMantani commented 3 months ago

Description

I am using ultranest in combination with JAX. When using the ReactiveNestedSampler, everything works fine if I include this line of code: jax.config.update("jax_enable_x64", True), which makes JAX use float64. This, however, makes the likelihood evaluation a bit slower.
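For reference, this is the standard JAX switch for 64-bit mode; per the JAX docs it should be set at startup, before the first JAX computation:

import jax

# Enable 64-bit floats globally; must run before any JAX operation,
# otherwise arrays default to float32.
jax.config.update("jax_enable_x64", True)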

However, if this is not set (and float32 is used instead), the ultranest run crashes:

  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/site-packages/ultranest/integrator.py", line 2373, in run
    for result in self.run_iter(
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/site-packages/ultranest/integrator.py", line 2645, in run_iter
    u, p, L = self._create_point(Lmin=Lmin, ndraw=ndraw, active_u=active_u, active_values=active_values)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/site-packages/ultranest/integrator.py", line 1888, in _create_point
    u, v, logl, nc, quality = self._refill_samples(Lmin, ndraw, nit)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/site-packages/ultranest/integrator.py", line 1779, in _refill_samples
    accepted = self.tregion.inside(v)
               ^^^^^^^^^^^^^^^^^^^^^^
  File "ultranest/mlfriends.pyx", line 1638, in ultranest.mlfriends.WrappingEllipsoid.inside
  File "ultranest/mlfriends.pyx", line 877, in ultranest.mlfriends._inside_ellipsoid
ValueError: Buffer dtype mismatch, expected 'float_t' but got 'float'

If I use the SliceSampler instead, things seem to work with float32 as well.
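For reference, I switch to the slice sampler like this (a sketch following the ultranest docs; param_names, loglike and transform are my own objects, and nsteps is just an illustrative choice):

import ultranest
import ultranest.stepsampler

sampler = ultranest.ReactiveNestedSampler(param_names, loglike, transform=transform)
# replace the default MLFriends region sampling with slice sampling
sampler.stepsampler = ultranest.stepsampler.SliceSampler(
    nsteps=2 * len(param_names),
    generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
)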

LucaMantani commented 3 months ago

Actually, I realised that while the run proceeds with the SliceSampler, it ultimately crashes when writing results:

  File "/Users/luca/Applications/colibri/colibri/ultranest_fit.py", line 136, in ultranest_fit
    ultranest_result = sampler.run(**ns_settings["Run_settings"])
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/site-packages/ultranest/integrator.py", line 2373, in run
    for result in self.run_iter(
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/site-packages/ultranest/integrator.py", line 2751, in run_iter
    self._update_results(main_iterator, saved_logl, saved_nodeids)
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/site-packages/ultranest/integrator.py", line 2885, in _update_results
    json.dump(results_simple, f, indent=4)
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/json/__init__.py", line 179, in dump
    for chunk in iterable:
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/json/encoder.py", line 432, in _iterencode
    yield from _iterencode_dict(o, _current_indent_level)
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/json/encoder.py", line 406, in _iterencode_dict
    yield from chunks
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/json/encoder.py", line 406, in _iterencode_dict
    yield from chunks
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/json/encoder.py", line 439, in _iterencode
    o = _default(o)
        ^^^^^^^^^^^
  File "/Users/luca/miniforge3/envs/colibri-dev/lib/python3.12/json/encoder.py", line 180, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type float32 is not JSON serializable

It seems related to JSON more than to ultranest: json doesn't know how to serialise float32 by default. This is probably an easy fix though; one can provide a custom encoder:

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    # Convert numpy float32 scalars to native Python floats before encoding.
    def default(self, obj):
        if isinstance(obj, np.float32):
            return float(obj)
        return json.JSONEncoder.default(self, obj)

a = np.float32(1)
print(json.dumps(a, cls=NumpyEncoder))  # prints: 1.0

This seems to work.
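The same pattern extends to the other numpy scalar and array types, if one wants the json.dump to be robust in general (a sketch, not what ultranest currently ships; GeneralNumpyEncoder is just an illustrative name):

import json
import numpy as np

class GeneralNumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):    # int32, int64, ...
            return int(obj)
        if isinstance(obj, np.floating):   # float32, float64, ...
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)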

JohannesBuchner commented 3 months ago

This may be a stupid question, but could you instead make your likelihood convert from jax arrays to a numpy array of the right dtype? Are you seeing a huge performance drop if you do that?

The first error you gave occurs because you cannot easily switch between 32-bit and 64-bit floats: the cython code is compiled for one kind of float. With a step sampler you are not using any cython code (which would again not work if you were to use a popstepsampler).

JohannesBuchner commented 3 months ago

I think myresult.astype(np.float_t) should do it
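Concretely, since np.float_t is the cython-side name and the matching numpy dtype from Python is np.float64 (myresult stands for whatever your jax likelihood returns):

import numpy as np

logl = np.asarray(myresult).astype(np.float64)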

LucaMantani commented 3 months ago

> This may be a stupid question, but could you instead make your likelihood convert from jax arrays to a numpy array of the right dtype? Are you seeing a huge performance drop if you do that?

When using dtype float64, the likelihood call is ~2 times slower on average. Maybe I did not understand your suggestion, but if I were to cast the jax arrays to numpy arrays, I wouldn't be able to use jit compilation, which is really impactful in terms of performance.

JohannesBuchner commented 3 months ago

I mean

import numpy as np
import jax.numpy as jnp

def mylikelihood(params):
    jparams = jnp.asarray(params)              # 1) convert params from numpy array to jax array
    logl = jax_loglike(jparams)                # 2) compute likelihood with jax (jax_loglike: your jitted function)
    return np.asarray(logl, dtype=np.float64)  # 3) convert likelihood back to numpy and return it

assuming a transform written in numpy (otherwise the transform also needs to convert its return value back to a numpy array).
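A jitted transform would be wrapped the same way (same imports as above; jax_transform is a placeholder for your jitted function):

def mytransform(cube):
    # cast back to a float64 numpy array so the compiled region code accepts it
    return np.asarray(jax_transform(jnp.asarray(cube)), dtype=np.float64)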

LucaMantani commented 3 months ago

Ok, interesting: that indeed seems to work as long as I do not jit-compile the prior_transform. If I compile it, then it does not work, and the error message is the one in the first post.

Actually, that's the only problem: I do not need to cast the likelihood to numpy if I do not jit-compile the prior.

JohannesBuchner commented 3 months ago

If so, then maybe we don't need PR #136, since it does not address all the other things that will not work (MLFriends, i.e. no stepsampler, and popstepsampler).

LucaMantani commented 3 months ago

It is true that it does not fix them, but that was not its purpose. Its purpose is to make the json.dump a bit more robust in the end, which incidentally allows the SliceSampler to work with float32, but that's basically a gift.

JohannesBuchner commented 3 months ago

It seems that this issue can be closed for now, as you have found a workaround. We could add a check in the code for whether the likelihood/transform return value is a jax array, but that seems a bit specific. Alternatively, some tutorials for using jax models with ultranest would be useful (e.g. as a blog post or in the docs) so that other people do not encounter the same problem.