Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

Performance degradation as max_samples increases #117

Closed: Joshuaalbert closed this issue 8 months ago.

Joshuaalbert commented 8 months ago

Describe the bug

In 2.3 I refactored the algorithm to improve it, but there's a loop where I'm pushing samples between iterations. This must be where the slowdown is happening; I'll profile it to determine this for sure. I don't think the entire state is needed on every iteration, so it should be fairly trivial to resolve the slowdown. The slowdown appears at around max_samples > 1e6.

JAXNS version

2.3.x

Joshuaalbert commented 8 months ago

perfetto_trace_1e8_constant_likelihood.json.gz
perfetto_trace_1e6_constant_likelihood.json.gz
perfetto_trace_1e3_constant_likelihood.json.gz
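The attachments are Perfetto traces (gzip-compressed JSON). A minimal sketch of how a comparable trace can be captured with `jax.profiler.trace`, assuming a JAX version whose `trace` context manager accepts `create_perfetto_trace`. `toy_step` and `capture_trace` are hypothetical stand-ins, not JAXNS code, and because the output folder layout may differ between JAX versions the sketch searches for the file recursively:

```python
import glob
import os

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    # Hypothetical stand-in workload; replace with the sampler run being profiled.
    return jnp.sum(jnp.sin(x) ** 2)


def capture_trace(log_dir: str, n: int) -> list:
    """Profile one call of toy_step and return paths of any Perfetto traces written."""
    x = jnp.linspace(0.0, 1.0, n)
    # Compile first so the trace shows steady-state execution, not compilation.
    toy_step(x).block_until_ready()
    with jax.profiler.trace(log_dir, create_perfetto_trace=True):
        toy_step(x).block_until_ready()
    # The file is written into a run-specific subfolder of log_dir.
    pattern = os.path.join(log_dir, "**", "perfetto_trace.json.gz")
    return glob.glob(pattern, recursive=True)
```

The resulting `perfetto_trace.json.gz` can be opened in the Perfetto UI (ui.perfetto.dev); renaming it per run makes traces at different sizes easy to compare side by side.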

Joshuaalbert commented 8 months ago

Others can run the benchmark with benchmarks/run_benchmark.sh.

Joshuaalbert commented 8 months ago

Timings on Python 3.11:

2.3.0
Avg. time taken: 7.17404 seconds.
The best 3 of 10 runs took 7.16335 seconds.

2.3.1
Avg. time taken: 7.20859 seconds.
The best 3 of 10 runs took 7.18449 seconds.

2.3.2
Avg. time taken: 7.38937 seconds.
The best 3 of 10 runs took 7.15356 seconds.

2.3.4
Avg. time taken: 7.20582 seconds.
The best 3 of 10 runs took 7.16257 seconds.

2.4.0
Avg. time taken: 7.19476 seconds.
The best 3 of 10 runs took 7.16229 seconds.

2.4.1
Avg. time taken: 7.20860 seconds.
The best 3 of 10 runs took 7.18259 seconds.
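Each version above reports a mean over 10 runs plus a "best 3 of 10" figure. A minimal harness producing the same two numbers, assuming "best 3 of 10" means the mean of the three fastest runs (benchmarks/run_benchmark.sh may compute it differently); `benchmark` and `_run_and_wait` are hypothetical helpers:

```python
import statistics
import time
from typing import Callable, Tuple


def _run_and_wait(fn: Callable[[], object]) -> None:
    out = fn()
    # JAX dispatches work asynchronously; if the result supports it, block
    # until it is ready so the measured time covers the actual computation.
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()


def benchmark(fn: Callable[[], object], n_runs: int = 10, best_k: int = 3) -> Tuple[float, float]:
    """Return (mean of all runs, mean of the best_k fastest runs), in seconds."""
    if not 1 <= best_k <= n_runs:
        raise ValueError("best_k must be between 1 and n_runs")
    _run_and_wait(fn)  # warm-up run, e.g. to exclude JIT compilation time
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _run_and_wait(fn)
        times.append(time.perf_counter() - start)
    best = sorted(times)[:best_k]
    return statistics.mean(times), statistics.mean(best)


if __name__ == "__main__":
    avg, best = benchmark(lambda: sum(range(1_000_000)))
    print(f"Avg. time taken: {avg:.5f} seconds.")
    print(f"The best 3 of 10 runs took {best:.5f} seconds.")
```

The warm-up call matters for JAX code: without it, the first timed run would include tracing and compilation and inflate the average.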

We still see long runtimes (about 7 seconds) for a model that only requires 30 samples. Something is wrong here, and it must be simple.

Joshuaalbert commented 8 months ago

My initial hypothesis was wrong: pushing only the front from one iteration to the next has no impact. Somehow the runtime still scales with max_samples. Let's list all the ops that operate on the whole state.sample_collection:

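It seems the termination conditions need to be made into an accumulant, i.e. updated incrementally each iteration rather than recomputed over the whole state.sample_collection.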
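A toy illustration of the difference, assuming the slowdown comes from a termination check that reduces over the full preallocated buffer on every iteration. The loop, the harmonic "weights", and both function names are hypothetical and not the JAXNS implementation. Both variants stop at the same iteration, but only the first does work proportional to max_samples in its termination check:

```python
from functools import partial

import jax
import jax.numpy as jnp


def _weight(n):
    # Toy per-sample "weight" 1/(n+1); the running total is the harmonic sum.
    return 1.0 / (n.astype(jnp.float32) + 1.0)


@partial(jax.jit, static_argnames=("max_samples",))
def run_full_reduction(max_samples, threshold):
    """Termination recomputed from the whole buffer on every iteration."""

    def cond(carry):
        buffer, n = carry
        # Reduces over all max_samples entries, even though only n are filled.
        return (jnp.sum(buffer) < threshold) & (n < max_samples)

    def body(carry):
        buffer, n = carry
        return buffer.at[n].set(_weight(n)), n + 1

    init = (jnp.zeros(max_samples, dtype=jnp.float32), jnp.asarray(0, dtype=jnp.int32))
    buffer, n = jax.lax.while_loop(cond, body, init)
    return jnp.sum(buffer), n


@partial(jax.jit, static_argnames=("max_samples",))
def run_accumulant(max_samples, threshold):
    """Termination driven by a running total carried in the loop state."""

    def cond(carry):
        _, n, total = carry
        # Constant-cost check: no reduction over the buffer.
        return (total < threshold) & (n < max_samples)

    def body(carry):
        buffer, n, total = carry
        w = _weight(n)
        return buffer.at[n].set(w), n + 1, total + w

    init = (
        jnp.zeros(max_samples, dtype=jnp.float32),
        jnp.asarray(0, dtype=jnp.int32),
        jnp.asarray(0.0, dtype=jnp.float32),
    )
    _, n, total = jax.lax.while_loop(cond, body, init)
    return total, n
```

With `threshold=3.0` both return `n == 11`, since the harmonic sum first reaches 3 at its 11th term. Both still pay a one-off cost for allocating the buffer, but the first variant's termination check costs time proportional to `max_samples` on every iteration, while the second's does not.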