danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License
82 stars, 10 forks

Wrap #142

Closed: danielward27 closed this 6 months ago

danielward27 commented 6 months ago

This introduces an API for creating wrappers of PyTree nodes, for example to apply parameterizations. It makes custom parameterizations more concise to define and easier to change or compose. The parameterizations are applied by unwrapping with the flowjax.wrappers.unwrap function. Unwrapping happens automatically before the main bijection and distribution methods run, and at the start of loss functions.
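To illustrate the general idea, here is a minimal sketch of the wrapper pattern in plain JAX. It is not the flowjax implementation: `Wrapper`, `Exp` and this `unwrap` are hypothetical stand-ins. The sketch assumes a wrapper is a registered PyTree node whose `unwrap()` method returns the "real" value. `Exp` stores an unconstrained `log_x` and unwraps to `exp(log_x)`, which keeps a parameter such as a scale positive while gradients flow to `log_x`. Unwrapping inner wrappers first is what allows wrappers to be nested and composed.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


class Wrapper:
    """Hypothetical base class: a PyTree node standing in for a value."""

    def unwrap(self):
        raise NotImplementedError


@register_pytree_node_class
class Exp(Wrapper):
    """Hypothetical parameterization: stores log_x, unwraps to exp(log_x)."""

    def __init__(self, log_x):
        self.log_x = log_x

    def unwrap(self):
        return jnp.exp(self.log_x)

    # PyTree registration: log_x is the only child, no static aux data.
    def tree_flatten(self):
        return (self.log_x,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def _is_wrapper(x):
    return isinstance(x, Wrapper)


def unwrap(tree):
    """Replace every Wrapper in a PyTree with its unwrapped value."""

    def _unwrap_node(node):
        if isinstance(node, Wrapper):
            # Unwrap the wrapper's own children first, so nested
            # wrappers (e.g. Exp(Exp(...))) compose correctly.
            children, aux = node.tree_flatten()
            inner = type(node).tree_unflatten(aux, [unwrap(c) for c in children])
            return inner.unwrap()
        return node

    # Treat wrappers as leaves so _unwrap_node sees them whole.
    return tree_map(_unwrap_node, tree, is_leaf=_is_wrapper)


def loss(params, x):
    # Apply the parameterizations at the start of the loss function.
    params = unwrap(params)
    return jnp.sum((x - params["loc"]) ** 2 / params["scale"])


params = {"scale": Exp(jnp.log(2.0)), "loc": jnp.array(0.0)}
print(unwrap(params))  # scale is ~2.0, loc is 0.0

# Gradients have the same structure as params, including the Exp node,
# so the optimizer updates the unconstrained log_x.
grads = jax.grad(loss)(params, jnp.array([1.0, 2.0]))
print(grads["scale"].log_x)
```

Because the wrapper is itself a PyTree node, swapping one parameterization for another (or nesting them) only changes the leaf in `params`; the loss function and the rest of the model stay the same.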

The high-level API remains mostly unchanged, but there are some significant breaking changes, listed approximately in order of user impact:

Apologies if I have missed anything, as this is quite a large change, but hopefully all the changes are straightforward. Let me know if anything is unclear.