Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/
Apache License 2.0

Lennard-Jones example takes a very long time and error estimates seem off #187

Closed Binbose closed 1 month ago

Binbose commented 2 months ago

Describe the bug: I tried to run the Lennard-Jones example from the documentation. With the default settings copied exactly, the algorithm did not finish even after an hour of running on an RTX 2070 GPU. I then reduced max_samples to 1e3 and it finished, but the results seem strange:

--------
Termination Conditions:
Reached max samples
--------
likelihood evals: 45916
samples: 960
phantom samples: 0
likelihood evals / sample: 47.8
phantom fraction (%): 0.0%
--------
logZ=9.58 +- 0.56
H=-7.68
ESS=1
--------
x[#]: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x[0]: 2.0 +- 1.6 | 0.2 / 3.5 / 3.5 | 3.5 | 3.5
x[1]: 2.2 +- 1.1 | 0.8 / 3.1 / 3.1 | 3.1 | 3.1
x[2]: 1.38 +- 0.52 | 1.07 / 1.07 / 2.41 | 1.07 | 1.07
x[3]: 1.11 +- 0.61 | 0.74 / 0.89 / 2.6 | 0.89 | 0.89
x[4]: 2.0 +- 1.0 | 1.1 / 1.1 / 3.5 | 1.1 | 1.1
x[5]: 1.8 +- 1.2 | 0.3 / 2.9 / 2.9 | 2.9 | 2.9
x[6]: 2.4 +- 1.2 | 0.5 / 3.3 / 3.3 | 3.3 | 3.3
x[7]: 0.9 +- 1.1 | 0.0 / 0.0 / 2.4 | 0.0 | 0.0
x[8]: 2.19 +- 0.96 | 0.58 / 1.89 / 3.57 | 1.89 | 1.89
x[9]: 0.91 +- 0.27 | 0.65 / 0.88 / 1.09 | 0.88 | 0.88
x[10]: 2.0 +- 1.0 | 1.1 / 1.1 / 3.5 | 1.1 | 1.1
x[11]: 1.8 +- 1.1 | 0.3 / 2.8 / 2.8 | 2.8 | 2.8
--------

Rerunning it with max_samples set to 5e3, I get this:

--------
Termination Conditions:
Reached max samples
--------
likelihood evals: 2312976
samples: 4800
phantom samples: 0
likelihood evals / sample: 481.9
phantom fraction (%): 0.0%
--------
logZ=70.36 +- 0.88
H=-25.51
ESS=0
--------
x[#]: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x[0]: 1.10405695 +- 1.2e-07 | 1.10405707 / 1.10405707 / 1.10405707 | 1.10405707 | 1.10405707
x[1]: 2.2 +- 0.0 | 2.2 / 2.2 / 2.2 | 2.2 | 2.2
x[2]: 1.9880091 +- 1.2e-07 | 1.98800921 / 1.98800921 / 1.98800921 | 1.98800921 | 1.98800921
x[3]: 1.10416484 +- 1.2e-07 | 1.10416496 / 1.10416496 / 1.10416496 | 1.10416496 | 1.10416496
x[4]: 2.2 +- 0.0 | 2.2 / 2.2 / 2.2 | 2.2 | 2.2
x[5]: 2.0 +- 0.0 | 2.0 / 2.0 / 2.0 | 2.0 | 2.0
x[6]: 2.1 +- 0.0 | 2.1 / 2.1 / 2.1 | 2.1 | 2.1
x[7]: 3.15060735 +- 2.4e-07 | 3.15060759 / 3.15060759 / 3.15060759 | 3.15060759 | 3.15060759
x[8]: 3.1 +- 0.0 | 3.1 / 3.1 / 3.1 | 3.1 | 3.1
x[9]: 1.0 +- 0.0 | 1.0 / 1.0 / 1.0 | 1.0 | 1.0
x[10]: 3.74954653 +- 2.4e-07 | 3.74954677 / 3.74954677 / 3.74954677 | 3.74954677 | 3.74954677
x[11]: 1.6 +- 0.0 | 1.6 / 1.6 / 1.6 | 1.6 | 1.6
--------

What I find particularly strange is that logZ jumps so much between the first and second run (9.58 +- 0.56 versus 70.36 +- 0.88). I would expect the error bars of the two runs to overlap, but they do not. Is this the expected behaviour? Similar things happen if I simplify the example to only two particles.
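For context on the ESS=1 and ESS=0 lines in the summaries above: an effective sample size that small usually means almost all of the posterior weight sits on a single sample, which is consistent with the near-zero standard deviations in the second run. The sketch below uses the standard Kish estimator on unnormalised log-weights. It is an illustration only. The function name `kish_ess` is hypothetical, and whether jaxns computes its reported ESS in exactly this way is an assumption, not something stated in this thread.

```python
import math


def kish_ess(log_weights):
    """Kish effective sample size, (sum w)^2 / sum(w^2), from log-weights.

    Hypothetical helper for illustration; not part of the jaxns API.
    """
    # Shift by the maximum before exponentiating to avoid overflow.
    shift = max(log_weights)
    weights = [math.exp(lw - shift) for lw in log_weights]
    total = sum(weights)
    total_sq = sum(w * w for w in weights)
    return total * total / total_sq


# Equal weights: every sample contributes, so ESS equals the sample count.
equal = kish_ess([0.0] * 100)

# One dominant weight: ESS collapses to roughly one sample.
dominated = kish_ess([0.0] + [-50.0] * 99)

print(equal, dominated)
```

When a run stops on "Reached max samples" with an ESS this low, the sampler may not yet have reached the bulk of the posterior, so it is plausible that logZ would still be changing between runs of different length.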

Joshuaalbert commented 2 months ago

Thanks for the interest in Lennard-Jones. This example is still in beta. I met some folks at MaxEnt who do this line of research and thought to make an example with jaxns, but didn't have time to complete it. I think I'll explicitly label examples as beta if still a work in progress. I'll ping you here when it's done.

Binbose commented 2 months ago

Hey, thank you for the quick reply! Maybe it isn't a bug (except for the error estimates of logZ, those seem off), and I just didn't wait long enough. What is roughly the expected runtime for the problem setup from the example on a single GPU?

Joshuaalbert commented 1 month ago

It runs much faster now. Upgrade to the latest version.

Joshuaalbert commented 1 month ago

@Binbose I will close for now since it's resolved.