Deprecate fit_to_variational_target, and add fit_to_key_based_loss.
Reason:
By default, fit_to_variational_target returned the parameters at which the minimum loss was reached. This provides some protection against unstable training, but it is a poor default for two reasons: 1) for very stochastic losses, the "best" parameters according to the minimum loss can be far from the actual best model; 2) some objectives provide useful gradients without any expectation that the "loss" will be minimized, e.g. some contrastive and adversarial approaches. I have also renamed the function to reflect its more general utility (it does not have to use flowjax distributions at all).
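A toy sketch of point 1, in plain Python rather than flowjax: gradient descent on a noisy estimate of theta**2, recording both the minimum-loss parameters and the final iterate. The objective, noisy_loss_and_grad, and the train helper with its return_best flag are hypothetical illustrations, not the flowjax API.

```python
import random


def noisy_loss_and_grad(theta, rng, noise_scale):
    """Hypothetical stochastic objective: the true loss is theta**2 (optimum at 0),
    but each loss estimate carries additive Gaussian noise of scale noise_scale,
    and each gradient estimate carries a small amount of Gaussian noise."""
    loss_estimate = theta**2 + rng.gauss(0.0, noise_scale)
    grad_estimate = 2 * theta + rng.gauss(0.0, 0.1)
    return loss_estimate, grad_estimate


def train(theta0, steps, lr, noise_scale, seed, return_best):
    """Plain gradient descent. If return_best is True, return the parameters at
    which the lowest *estimated* loss was observed; otherwise the final iterate."""
    rng = random.Random(seed)
    theta = theta0
    best_theta, best_loss = theta0, float("inf")
    for _ in range(steps):
        loss, grad = noisy_loss_and_grad(theta, rng, noise_scale)
        if loss < best_loss:
            # The loss was evaluated at the current theta, before the update.
            best_theta, best_loss = theta, loss
        theta -= lr * grad
    return best_theta if return_best else theta


if __name__ == "__main__":
    for return_best in (True, False):
        theta = train(5.0, 500, 0.05, noise_scale=10.0, seed=0, return_best=return_best)
        print(f"return_best={return_best}: theta={theta:.3f}")
```

With a large noise_scale, the recorded minimum tends to be driven by a lucky noise draw rather than by the parameter value, so the minimum-loss parameters can sit well away from the optimum even after the final iterate has converged near 0.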
We also move the training loops into the same file, meaning the module flowjax.train.data_fit is deprecated.