Joshuaalbert / jaxns

Probabilistic Programming and Nested sampling in JAX
https://jaxns.readthedocs.io/

Add unpacked U-space parameterisation to remove an unnecessary concatenation. #177

Open · Joshuaalbert opened this issue 1 month ago


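Is your feature request related to a problem? Please describe.
Currently we operate on a flat [0,1]^D representation of U-space. This layout is artificially imposed for convenience: forming it requires ravelling each random variable's (RV's) U-values and concatenating them into one vector, and recovering the RVs requires indexing into that vector and reshaping. We can speed up the framework by removing this step and keeping U-space unpacked, with one array per RV.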
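To make the overhead concrete, here is a minimal sketch of the two layouts. It is not jaxns's actual API: the RV names and shapes in `SHAPES` and the helpers `rvs_to_flat`, `flat_to_rvs`, `sample_unpacked` and `prior_transform` are hypothetical, and the prior transform is an illustrative uniform prior on [-5, 5]. The flat path pays for a ravel/concatenate on the way in and an index/reshape on the way out; the unpacked path feeds each RV's U-array straight into its transform.

```python
import math

import jax
import jax.numpy as jnp

# Hypothetical model: two RVs with different shapes (names are illustrative only).
SHAPES = {"x": (3,), "y": (2, 2)}


def rvs_to_flat(U_dict):
    """Flat [0,1]^D layout: ravel and concatenate every RV's U-values."""
    return jnp.concatenate([jnp.ravel(U_dict[name]) for name in SHAPES])


def flat_to_rvs(U_flat):
    """Flat layout, reverse direction: index and reshape to recover each RV's U-values."""
    out, start = {}, 0
    for name, shape in SHAPES.items():
        size = math.prod(shape)
        out[name] = U_flat[start:start + size].reshape(shape)
        start += size
    return out


def sample_unpacked(key):
    """Unpacked layout: one uniform array per RV, no concatenation needed."""
    keys = jax.random.split(key, len(SHAPES))
    return {name: jax.random.uniform(k, shape) for (name, shape), k in zip(SHAPES.items(), keys)}


def prior_transform(U_dict):
    # Illustrative prior: map each U in [0, 1) to a uniform RV on [-5, 5).
    return {name: 10.0 * u - 5.0 for name, u in U_dict.items()}


U = sample_unpacked(jax.random.PRNGKey(0))
rvs_flat_path = prior_transform(flat_to_rvs(rvs_to_flat(U)))  # concat + slice + reshape
rvs_unpacked_path = prior_transform(U)                         # direct, no repacking
```

Both paths give identical RV values; the unpacked one simply skips the round trip through the D-dimensional vector.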

This would speed up anything that uses jaxns.framework.*, including gradient-based optimisation routines that use jaxns for the model definition, e.g. DSA2000.
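For gradient-based optimisation, JAX already differentiates with respect to arbitrary pytrees, so an unpacked U-space can be optimised directly. The sketch below assumes a hypothetical dict-of-arrays U-space, a toy prior transform and a toy Gaussian log-likelihood; none of these names come from jaxns. `jax.grad` returns gradients with the same dict structure as `U`, so no ravel/unravel step appears in the update.

```python
import jax
import jax.numpy as jnp

LR = 0.01  # illustrative learning rate


def prior_transform(U):
    # Illustrative uniform prior on [-5, 5] for every RV (hypothetical model).
    return {name: 10.0 * u - 5.0 for name, u in U.items()}


def log_likelihood(rvs):
    # Toy Gaussian log-likelihood centred on 1.0 (up to an additive constant).
    return -0.5 * sum(jnp.sum((v - 1.0) ** 2) for v in rvs.values())


def objective(U):
    return log_likelihood(prior_transform(U))


@jax.jit
def ascent_step(U):
    # Gradients share U's pytree structure, so there is nothing to unpack.
    grads = jax.grad(objective)(U)
    # Gradient ascent, clipped to keep every U-value inside [0, 1].
    return jax.tree_util.tree_map(lambda u, g: jnp.clip(u + LR * g, 0.0, 1.0), U, grads)


U0 = {"x": jnp.full((3,), 0.5), "y": jnp.full((2, 2), 0.5)}
U1 = ascent_step(U0)
```

With this toy objective and learning rate, a single step moves every U-value from 0.5 to 0.6, i.e. every RV to the likelihood peak at 1.0.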