Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

Use tfp.math.associative_scan for associative binary ops #163

Closed: Joshuaalbert closed this issue 2 months ago

Joshuaalbert commented 3 months ago

Is your feature request related to a problem? Please describe. Computing evidence over samples is an associative binary operation, and there are likely more such operations buried in the computations.
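To make the associativity point concrete, here is a minimal JAX sketch. It assumes the per-sample log-evidence contributions log_dZ = log L_i + log dX_i are already available (the values below are made up for illustration), and it uses jax.lax.associative_scan, JAX's built-in parallel prefix scan, rather than the tfp function named in the title. The function names are hypothetical. Because logaddexp is associative, the prefix scan reproduces the sequential running log-evidence.

```python
import jax
import jax.numpy as jnp

def cumulative_log_evidence_sequential(log_dZ):
    # One dependent step per sample: log Z_i = logaddexp(log Z_{i-1}, log dZ_i).
    def body(log_Z, log_dz):
        log_Z = jnp.logaddexp(log_Z, log_dz)
        return log_Z, log_Z
    init = jnp.full((), -jnp.inf, dtype=log_dZ.dtype)  # log(0): the empty sum
    _, log_Z = jax.lax.scan(body, init, log_dZ)
    return log_Z

def cumulative_log_evidence_parallel(log_dZ):
    # logaddexp is associative, so an all-prefix scan gives the same running
    # sums with O(log n) depth instead of n dependent steps.
    return jax.lax.associative_scan(jnp.logaddexp, log_dZ)

# Made-up log-likelihoods and prior-volume widths, for illustration only.
log_L = jnp.array([-3.0, -2.0, -1.5, -1.0, -0.8])
log_dX = jnp.log(jnp.array([0.5, 0.25, 0.125, 0.0625, 0.03125]))
log_dZ = log_L + log_dX

seq = cumulative_log_evidence_sequential(log_dZ)
par = cumulative_log_evidence_parallel(log_dZ)
print(jnp.allclose(seq, par))  # True
```

The last element of either result is the log of the total evidence over these samples.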

We can simply replace cumulative_op_static with an all-prefix-sum (parallel scan) version. This will speed up:

  1. sample_evidence
  2. m_step of EvidenceMaximisation
  3. compute_enclosed_prior_volume
  4. compute_evidence_stats
  5. count_crossed_edges
  6. Computing evidence_calc_with_remaining in _inter_sync_shrinkage_process
  7. get_sample_from_seed in UniDimSliceSampler and MultiDimSliceSampler
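A possible shape for the replacement, as a sketch only: cumulative_op_sequential and cumulative_op_parallel below are hypothetical helpers and do not reproduce the actual signature of cumulative_op_static in jaxns. The parallel version is valid only when op is associative, both operands have the same type, and op works elementwise on batched arrays. An init value is folded into the first element so the output matches a sequential left fold. The usage applies it to the log enclosed prior volume (item 3), assuming the standard expected shrinkage E[log t] = -1/n for n live points; the live-point counts are made up.

```python
import jax
import jax.numpy as jnp

def cumulative_op_sequential(op, init, xs):
    # Reference: left fold over xs with lax.scan, emitting every partial result.
    def body(carry, x):
        carry = op(carry, x)
        return carry, carry
    _, ys = jax.lax.scan(body, jnp.asarray(init, dtype=xs.dtype), xs)
    return ys

def cumulative_op_parallel(op, init, xs):
    # Parallel prefix version; assumes op is associative and elementwise over
    # the leading axis. Folding init into xs[0] makes index i equal to
    # op(...op(op(init, xs[0]), xs[1])..., xs[i]), as in the sequential fold.
    xs = xs.at[0].set(op(init, xs[0]))
    return jax.lax.associative_scan(op, xs)

# Usage: log X_i = log X_0 + sum_{j<=i} log t_j with log X_0 = 0,
# using the expected shrinkage E[log t] = -1/n_live at each step.
n_live = jnp.array([100.0, 100.0, 99.0, 98.0, 97.0])
log_t = -1.0 / n_live
log_X_seq = cumulative_op_sequential(jnp.add, 0.0, log_t)
log_X_par = cumulative_op_parallel(jnp.add, 0.0, log_t)
print(jnp.allclose(log_X_seq, log_X_par))  # True
```

Ops whose accumulator and element types differ cannot be handed to associative_scan directly. They would first need to be recast as an associative combine over a common element type, for example a tuple of arrays, since jax.lax.associative_scan also accepts pytrees of arrays.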
Joshuaalbert commented 3 months ago

Note, though, that compile times could grow.
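One way to check the compile-time concern before committing to the change is to time compilation on its own. The sketch below (seq_cumsum, par_cumsum and compile_seconds are hypothetical names) uses jax.jit(...).lower(...).compile() so that only tracing, lowering and XLA compilation are measured, not execution. Results depend on backend and array length, so no numbers are shown.

```python
import time
import jax
import jax.numpy as jnp

def seq_cumsum(xs):
    # Sequential cumulative sum: a single compiled loop body.
    def body(c, x):
        c = c + x
        return c, c
    _, ys = jax.lax.scan(body, jnp.zeros((), xs.dtype), xs)
    return ys

def par_cumsum(xs):
    # Parallel prefix sum: the traced program depends on len(xs).
    return jax.lax.associative_scan(jnp.add, xs)

def compile_seconds(fn, xs):
    # Time tracing + lowering + XLA compilation only.
    start = time.perf_counter()
    jax.jit(fn).lower(xs).compile()
    return time.perf_counter() - start

xs = jnp.ones(10_000)
for name, fn in [("lax.scan", seq_cumsum), ("associative_scan", par_cumsum)]:
    print(name, f"{compile_seconds(fn, xs):.3f} s")
```

Some growth is expected: jax.lax.associative_scan unrolls its recursion into O(log n) levels of strided slice-and-combine operations in the traced program, whereas lax.scan compiles one loop body regardless of n.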