danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Softplus #85

Closed danielward27 closed 1 year ago

danielward27 commented 1 year ago

Add a SoftPlus bijection and use it as the default way to enforce positivity of the scale parameters in Affine and TriangularAffine. Note that this will also affect most distributions, because loc-scale distributions are implemented as Affine transformations of their standardised versions.
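To illustrate why a change to Affine carries over to loc-scale distributions, here is a minimal sketch in plain Python (not flowjax code; the function names are hypothetical). It computes a normal log density by pushing a standard normal through an affine map and applying the change-of-variables correction, then compares the result with the closed-form density.

```python
import math


def standard_normal_log_prob(z: float) -> float:
    """Log density of the standard normal N(0, 1) at z."""
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def affine_normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log density of N(loc, scale^2), treated as x = loc + scale * z."""
    z = (x - loc) / scale  # inverse of the affine map
    # Change of variables: subtract the log |det Jacobian| of the forward map.
    return standard_normal_log_prob(z) - math.log(scale)


def closed_form_normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Textbook log density of N(loc, scale^2), used as a reference."""
    return (
        -0.5 * ((x - loc) / scale) ** 2
        - math.log(scale)
        - 0.5 * math.log(2.0 * math.pi)
    )


if __name__ == "__main__":
    x, loc, scale = 1.3, 0.5, 2.0
    print(affine_normal_log_prob(x, loc, scale))
    print(closed_form_normal_log_prob(x, loc, scale))
```

Since the scale enters only through the affine map, whatever constraint the map places on its scale applies to the distribution as well.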

This change was made after experiencing instability, particularly when affine transformations were parameterised by neural networks (e.g. in masked autoregressive flows and coupling flows), where the positive scale parameter can explode. See also the discussion at https://github.com/pyro-ppl/numpyro/issues/855.
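As a rough illustration of the stability argument, the sketch below compares softplus with an exponential map as ways of turning an unconstrained value (for example, a neural network output) into a positive scale. The exponential map is an assumption chosen only for comparison, since the text above does not say which parameterisation was used before. The helper names are hypothetical and this is plain Python, not flowjax's implementation.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def exp_scale(raw: float) -> float:
    """Positive scale via exp: grows exponentially in the raw value."""
    return math.exp(raw)


def softplus_scale(raw: float) -> float:
    """Positive scale via softplus: grows roughly linearly for large raw values."""
    return softplus(raw)


if __name__ == "__main__":
    # Compare how each map responds as the unconstrained value grows.
    for raw in (-5.0, 0.0, 5.0, 20.0):
        print(f"raw={raw:6.1f}  exp={exp_scale(raw):.4g}  softplus={softplus_scale(raw):.4g}")
```

Both maps return strictly positive values, but for large inputs softplus stays close to the input itself, so a moderately large network output no longer produces an enormous scale.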

Note this will introduce some breaking changes: