Breaking changes:
- Change in the `step` function signature in `flowjax.train.train_utils`. This was (and still is) undocumented in the API.
- Change in the loss signatures to always take `params` first (instead of `static`). Even for custom loss functions this often won't break anything, as `eqx.combine` is usually symmetric in our use cases, i.e. `combine(pytree1, pytree2)` equals `combine(pytree2, pytree1)`. However, code involving loss functions should be updated to pass the arguments in the right order (`params`, then `static`).