Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

`shard_map` parallel nested sampling #172

Status: Open. Joshuaalbert opened this issue 2 months ago.

Joshuaalbert commented 2 months ago

Implement sharded parallel nested sampling that scales to multi-node, multi-device hardware, combining SIMD (vectorised) and sharded computation.
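A minimal sketch of the sharded part, assuming the live points (the frontier) are split along a one-dimensional device mesh whose axis we call `"frontier"` (a hypothetical name). Each device holds a block of live-point log-likelihoods; `all_gather` gives every shard the same view of the frontier, so all shards agree on which M points fall below the current shell. `select_shell` and `discard_mask` are hypothetical helpers for illustration, not JAXNS API, and gathering the whole frontier is only reasonable for modest numbers of live points.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

# One mesh axis over all available devices; the frontier is sharded along it.
mesh = Mesh(np.array(jax.devices()), ("frontier",))


def select_shell(log_l_local, m_discard):
    """Per-shard body: mark this shard's points that lie in the global M-lowest shell."""
    # Gather the whole frontier so every shard sees identical values.
    all_log_l = jax.lax.all_gather(log_l_local, "frontier", tiled=True)
    # Global synchronisation point: the M-th lowest log-likelihood across all shards.
    threshold = jnp.sort(all_log_l)[m_discard - 1]
    return log_l_local <= threshold


def discard_mask(log_l, m_discard, mesh):
    """Boolean mask (sharded like log_l) of the M live points to discard this step."""
    f = shard_map(
        partial(select_shell, m_discard=m_discard),
        mesh=mesh,
        in_specs=(P("frontier"),),
        out_specs=P("frontier"),
    )
    return jax.jit(f)(log_l)
```

The number of live points must be divisible by the number of devices on the mesh. Points flagged by the mask would then be replaced, each replacement running independently on the shard that owns it.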

Joshuaalbert commented 3 weeks ago

Following a discussion with John Skilling at MaxEnt 2024, JAXNS's parallel implementation will be updated. It will perform parallel NS by discarding M points per step and replacing all of them in parallel. This imposes a global synchronisation at every step: the next shell cannot be discarded until the slowest of the M replacements has finished. Even so, this will still perform well on distributed hardware and make good use of accelerator meshes. The sharded axis will be the frontier of samples (i.e. the live points).

This improves on the current parallel NS, which runs fully independent samplers (each at reduced resolution, i.e. with fewer live points) and then needs a secondary phase to align their depths. Our sample tree implementation will be preserved for posterity and for dynamic refinement, but I'll ultimately choose whichever structure is the most performant, in line with JAXNS's mission.
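The discard-M, replace-M scheme can be sketched in plain Python on a toy problem. This is an illustration under stated assumptions, not the JAXNS implementation: the prior is uniform on [0, 1], the likelihood is a hypothetical narrow Gaussian, and the constrained-prior draw uses the exact interval that this toy likelihood allows (real problems need a constrained sampler such as slice sampling). The prior-volume bookkeeping treats the M discards in one shell as M successive single-point deaths from N, N-1, ..., N-M+1 live points, which is a common way to account for batched removal.

```python
import math
import random

SIGMA = 0.05  # width of the hypothetical toy likelihood


def log_likelihood(x, sigma=SIGMA):
    """Toy 1-D Gaussian log-likelihood centred at 0.5 (hypothetical test problem)."""
    return -((x - 0.5) ** 2) / (2.0 * sigma ** 2)


def sample_constrained(log_l_star, rng, sigma=SIGMA):
    """Draw x ~ Uniform[0, 1] restricted to log_likelihood(x) > log_l_star.

    For this toy likelihood the constrained region is an interval, so it can
    be sampled exactly; this shortcut does not generalise.
    """
    r = sigma * math.sqrt(-2.0 * log_l_star)
    lo, hi = max(0.0, 0.5 - r), min(1.0, 0.5 + r)
    return rng.uniform(lo, hi)


def parallel_nested_sampling(n_live=200, m_discard=20, tol=1e-4, seed=0):
    """Return a log-evidence estimate using M discards per step."""
    rng = random.Random(seed)
    live = [rng.random() for _ in range(n_live)]
    log_ls = [log_likelihood(x) for x in live]
    log_x = 0.0  # log prior volume enclosed by the current contour
    z = 0.0      # running evidence estimate
    while True:
        # 1. Discard the M lowest-likelihood live points as one shell.
        worst = sorted(range(n_live), key=lambda i: log_ls[i])[:m_discard]
        for k, i in enumerate(worst):
            # Expected shrinkage of log X when one of (n_live - k) points dies.
            new_log_x = log_x - 1.0 / (n_live - k)
            z += math.exp(log_ls[i]) * (math.exp(log_x) - math.exp(new_log_x))
            log_x = new_log_x
        # 2. Replace all M points above the shell's highest discarded likelihood.
        #    The draws are independent, so in JAX they could be vmapped or sharded;
        #    the step ends only when the slowest replacement is done.
        log_l_star = log_ls[worst[-1]]
        for i in worst:
            live[i] = sample_constrained(log_l_star, rng)
            log_ls[i] = log_likelihood(live[i])
        # 3. Stop once the live points can add only a small fraction to Z.
        if math.exp(log_x) * max(math.exp(l) for l in log_ls) < tol * z:
            break
    # Add the contribution of the final live points.
    z += math.exp(log_x) * sum(math.exp(l) for l in log_ls) / n_live
    return math.log(z)
```

For this toy problem the exact evidence is very close to `log(0.05 * sqrt(2 * pi))`, about -2.077, so estimates for different `m_discard` values can be compared against it. Larger `m_discard` means fewer synchronised steps, at the cost of the per-step wait on the slowest replacement.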