This pull request refactors the training scripts and bumps the version to 10.
Breaking changes:
Changes to fit_to_data:
Removed clip_norm as an argument to fit_to_data. This was done because 1) the argument lists were already quite long, 2) gradient clipping can still be achieved by passing an optimizer, and 3) clipping seems to be less important now that Affine uses softplus as a positivity constraint.
fit_to_data now accepts a loss function. This allows the same training script to be used both for SNPE/contrastive learning and for standard maximum likelihood fitting of distributions.
Changed default parameters: batch_size 256 -> 100 and max_epochs 50 -> 100.
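The following sketch shows what a pluggable loss means for a training loop. It is written in stdlib Python rather than JAX so it runs anywhere. `fit` takes `loss_fn` as an argument, and `gaussian_nll` is a maximum-likelihood loss for a 1-D Gaussian. All names here (`fit`, `gaussian_nll`, `numerical_grad`) are hypothetical, and this is not flowjax's implementation. The finite-difference gradient stands in for `jax.grad`. A contrastive loss for SNPE would plug into the same `loss_fn` slot.

```python
import math
import random
from typing import Callable, List, Sequence

Params = List[float]
LossFn = Callable[[Params, Sequence[float]], float]


def gaussian_nll(params: Params, x: Sequence[float]) -> float:
    """Maximum-likelihood loss: mean negative log density under N(mu, sigma^2)."""
    mu, log_sigma = params
    sigma = math.exp(log_sigma)
    return sum(
        0.5 * ((xi - mu) / sigma) ** 2 + log_sigma + 0.5 * math.log(2 * math.pi)
        for xi in x
    ) / len(x)


def numerical_grad(loss_fn: LossFn, params: Params, x: Sequence[float],
                   eps: float = 1e-5) -> Params:
    """Central finite differences, standing in for jax.grad in this sketch."""
    grads = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += eps
        down[i] -= eps
        grads.append((loss_fn(up, x) - loss_fn(down, x)) / (2 * eps))
    return grads


def fit(params: Params, x: Sequence[float], loss_fn: LossFn,
        learning_rate: float = 0.1, steps: int = 300) -> Params:
    """Gradient descent on whatever loss is passed in; nothing is hard-coded."""
    params = list(params)
    for _ in range(steps):
        grads = numerical_grad(loss_fn, params, x)
        params = [p - learning_rate * g for p, g in zip(params, grads)]
    return params


random.seed(0)
data = [random.gauss(2.0, 0.5) for _ in range(200)]
mu_hat, log_sigma_hat = fit([0.0, 0.0], data, gaussian_nll)
```

Swapping `gaussian_nll` for any other function with the same `(params, x)` signature changes what is fitted without touching `fit`. With the maximum-likelihood loss, the result recovers the sample mean and the (biased) sample standard deviation.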
fit_to_variational_target has similarly had the clip_norm argument removed.
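With clip_norm removed from both fit_to_data and fit_to_variational_target, clipping is configured on the optimizer that is passed in. With optax this is typically `optax.chain(optax.clip_by_global_norm(max_norm), optax.adam(learning_rate))`. Check each training function's signature for the exact keyword it uses for the optimizer. The stdlib sketch below only illustrates what such a chain does to a gradient. `clip_by_global_norm`, `sgd` and `chain` here are hypothetical stand-ins for the optax transformations, and `sgd` replaces Adam for brevity.

```python
import math
from typing import Callable, List

# A "transformation" maps a list of gradient values to a new list. This loosely
# mirrors optax's gradient transformations but is a hypothetical, stateless
# simplification for illustration only.
Transform = Callable[[List[float]], List[float]]


def clip_by_global_norm(max_norm: float) -> Transform:
    """Rescale gradients so that their global L2 norm is at most max_norm."""
    def update(grads: List[float]) -> List[float]:
        norm = math.sqrt(sum(g * g for g in grads))
        if norm <= max_norm:
            return list(grads)
        scale = max_norm / norm
        return [g * scale for g in grads]
    return update


def sgd(learning_rate: float) -> Transform:
    """Turn gradients into parameter updates (plain SGD, standing in for Adam)."""
    def update(grads: List[float]) -> List[float]:
        return [-learning_rate * g for g in grads]
    return update


def chain(*transforms: Transform) -> Transform:
    """Apply transformations left to right, in the spirit of optax.chain."""
    def update(grads: List[float]) -> List[float]:
        for transform in transforms:
            grads = transform(grads)
        return grads
    return update


optimizer = chain(clip_by_global_norm(1.0), sgd(0.1))
updates = optimizer([3.0, 4.0])  # global norm 5 is clipped to 1 before the SGD step
```

Keeping clipping inside the optimizer means the training functions need no clipping-specific argument. A caller who wants no clipping simply leaves the clipping transformation out of the chain.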
ElboLoss has been moved to flowjax.train.losses, where losses for fitting with fit_to_data by contrastive learning and by maximum likelihood have also been added.
The step function has been moved to flowjax.train.train_utils and is now used by both fit_to_data and fit_to_variational_target.
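ElboLoss is the variational-inference loss, presumably the one minimized by fit_to_variational_target. To illustrate the quantity such a loss estimates, here is a stdlib sketch of a reparameterized Monte Carlo estimate of the negative ELBO, E_q[log q(z) - log p(z)]. It assumes a 1-D Gaussian q = N(mu, sigma^2) and an unnormalized target density. `negative_elbo` and its argument layout are hypothetical. flowjax's ElboLoss operates on flows in JAX, and its exact interface may differ.

```python
import math
import random
from typing import Callable, List


def gaussian_log_prob(z: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma^2) at z."""
    return -0.5 * ((z - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def negative_elbo(params: List[float], target_log_prob: Callable[[float], float],
                  eps_samples: List[float]) -> float:
    """Monte Carlo estimate of E_q[log q(z) - log p(z)] for q = N(mu, sigma^2).

    Samples are reparameterized as z = mu + sigma * eps with eps ~ N(0, 1), so for
    fixed eps_samples the estimate is a smooth function of params.
    """
    mu, log_sigma = params
    sigma = math.exp(log_sigma)
    total = 0.0
    for eps in eps_samples:
        z = mu + sigma * eps
        total += gaussian_log_prob(z, mu, sigma) - target_log_prob(z)
    return total / len(eps_samples)


def target(z: float) -> float:
    # Unnormalized log density of N(3, 1); the missing constant only shifts the ELBO.
    return -0.5 * (z - 3.0) ** 2


random.seed(0)
eps_samples = [random.gauss(0.0, 1.0) for _ in range(2000)]
good = negative_elbo([3.0, 0.0], target, eps_samples)  # q matches the target's shape
bad = negative_elbo([0.0, 0.0], target, eps_samples)   # q centred in the wrong place
```

When q matches the normalized target exactly, every sample contributes the same value, -0.5 log(2π), which is minus the log normalizer of the target. Any mismatch adds the KL divergence on top of that.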
Removed deprecated AdditiveLinearCondition.
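The step function in flowjax.train.train_utils can be shared by fit_to_data and fit_to_variational_target because a training step only needs a scalar loss of the parameters. Below is a hypothetical stdlib sketch of that design. `step` takes a loss and its extra arguments, and two toy loops reuse it: one iterates over data batches, the other fits against a fixed target with no data. None of these names or signatures are flowjax's. Finite differences stand in for `jax.grad`, and plain gradient descent stands in for the optax optimizer.

```python
from typing import Callable, List, Sequence, Tuple

Params = List[float]


def step(params: Params, loss_fn: Callable[..., float], *loss_args,
         learning_rate: float = 0.1, eps: float = 1e-6) -> Tuple[Params, float]:
    """One gradient-descent update for any scalar loss_fn(params, *loss_args)."""
    loss = loss_fn(params, *loss_args)
    grads = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += eps
        down[i] -= eps
        grads.append((loss_fn(up, *loss_args) - loss_fn(down, *loss_args)) / (2 * eps))
    return [p - learning_rate * g for p, g in zip(params, grads)], loss


def squared_error(params: Params, batch: Sequence[float]) -> float:
    """Data-based loss: mean squared distance between the batch and params[0]."""
    return sum((x - params[0]) ** 2 for x in batch) / len(batch)


def distance_to_target(params: Params, target: float) -> float:
    """Data-free loss: squared distance between params[0] and a fixed target."""
    return (params[0] - target) ** 2


def fit_to_batches(params: Params, batches: List[List[float]], epochs: int = 50) -> Params:
    # Loop over data batches, reusing the shared step.
    for _ in range(epochs):
        for batch in batches:
            params, _ = step(params, squared_error, batch)
    return params


def fit_to_target(params: Params, target: float, steps: int = 200) -> Params:
    # Loop with no data at all, reusing the same step.
    for _ in range(steps):
        params, _ = step(params, distance_to_target, target)
    return params


data_fit = fit_to_batches([0.0], [[1.0, 3.0], [2.0, 2.0]])
target_fit = fit_to_target([0.0], 3.0)
```

Because `step` forwards `*loss_args` untouched, each training loop decides what the loss needs (a batch, a target, a random key) while the update logic lives in one place.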