This introduces an API for creating wrappers of Pytree nodes, such as for applying parameterizations. This allows for more concise definitions of custom parameterizations and greater flexibility in changing/composing them. These custom parameterizations can be applied using the flowjax.wrappers.unwrap function. They are automatically applied before the main bijection and distribution methods, as well as at the start of loss functions.
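To make the wrapper idea concrete, here is a minimal pure-Python sketch of the concept. It is not flowjax's implementation: the names ToyWrapper, ToyParameterize and toy_unwrap are hypothetical, and plain dicts/lists stand in for Pytrees. It only illustrates how a node can store an unconstrained value together with its parameterization, and how one unwrap pass applies every parameterization in a nested structure.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable


class ToyWrapper:
    """Hypothetical base class: a node that can produce its unwrapped value."""

    def unwrap(self) -> Any:
        raise NotImplementedError


@dataclass
class ToyParameterize(ToyWrapper):
    """Stores a raw (unconstrained) value and the function that constrains it."""

    fn: Callable[[float], float]
    raw: float

    def unwrap(self) -> float:
        return self.fn(self.raw)


def toy_unwrap(tree: Any) -> Any:
    """Recursively replace each wrapper in nested dicts/lists/tuples by its value."""
    if isinstance(tree, ToyWrapper):
        # Unwrap again so that nested (composed) wrappers are also resolved.
        return toy_unwrap(tree.unwrap())
    if isinstance(tree, dict):
        return {key: toy_unwrap(value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_unwrap(value) for value in tree)
    return tree


# A toy "affine" model whose scale is kept positive via an exp parameterization.
affine = {"loc": 0.5, "scale": ToyParameterize(math.exp, 0.0)}

# Accessing the attribute directly gives the wrapper, not the constrained value.
wrapped_scale = affine["scale"]

# Unwrapping the whole model applies the parameterization everywhere.
unwrapped = toy_unwrap(affine)
```

In this sketch, changing a parameterization amounts to swapping one node for another wrapper, which is the flexibility the wrapper approach is aiming for.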
The high-level API remains mostly unchanged, but there are some significant breaking changes, listed approximately in order of user impact:
The filter_spec argument has been removed from fit_to_data and fit_to_variational_target. Instead, any untrainable model parts/arrays should be wrapped in flowjax.wrappers.NonTrainable. A simple example is provided in the FAQ.
The Vmap arguments have been renamed from in_axis to in_axes and in_axis_condition to in_axes_condition to align more closely with JAX/Equinox naming conventions. Additionally, Vmap in_axes must now be defined with a structure that is compatible with the unwrapped object.
The triangular array attribute in TriangularAffine has been renamed from arr to triangular, and is now a wrapped parameterization.
Accessing certain attributes may now return a wrapped object (e.g. affine.scale). To obtain the unwrapped version (after the constraint/parameterization is applied), use flowjax.wrappers.unwrap.
Constraints can no longer be passed to the Affine and TriangularAffine __init__ methods. Instead, eqx.tree_at can be used to modify parameterizations if required.
flowjax.nn has been removed. Masked networks are now defined in the same file as the bijections they are used in (i.e., in flowjax.bijections.masked_autoregressive and flowjax.bijections.block_neural_autoregressive).
Masks in flowjax.masks now return boolean arrays rather than integers.
Many (mostly) private attributes that stored unbounded parameters (e.g., _scale, _diag) have been removed, as reparameterization is now defined through wrappers. Some non-private attributes have also been removed (unbounded_x_pos, unbounded_y_pos, and unbounded_derivatives in RationalQuadraticSpline).
Apologies if I have missed anything, as this is quite a large change, but hopefully all changes are straightforward. Let me know if anything is unclear.