joshuawchen closed this issue 5 months ago.
Hi! I have a private fork of flowjax in which I implement cross-validation over the train/test split of the data, so as not to overfit to the training data (e.g. in `fit_to_data`/`fit_to_variational_target`). It works quite well. Would you be open to me contributing this to flowjax via a pull request?
Hello, thanks for your offer!
Within `fit_to_data`, we already perform a form of cross-validation (a held-out validation split) and use it to avoid overfitting. For variational approaches, cross-validation is more nuanced. For some models (e.g. amortised variational inference approaches with multiple observations), we could use predictive performance on held-out data to avoid overfitting and to improve generalisation to new points. However, for many models it does not make sense, e.g. any model with a single observation, where there is no data left to hold out. For this reason, I'd initially lean against adding cross-validation to the variational inference training.
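For readers unfamiliar with the idea, training against a held-out validation split with early stopping can be sketched in plain JAX. This is a minimal sketch under stated assumptions, not flowjax's implementation. It fits a 1D Gaussian by full-batch gradient descent in place of a normalizing flow. The helper `fit_with_early_stopping` and its parameters (`val_prop`, `max_patience`, and so on) are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def neg_log_likelihood(params, x):
    """Mean negative log-likelihood of a 1D Gaussian with learnable mean and log-std."""
    mean, log_std = params
    z = (x - mean) / jnp.exp(log_std)
    return jnp.mean(0.5 * z**2 + log_std + 0.5 * jnp.log(2 * jnp.pi))


def fit_with_early_stopping(key, x, val_prop=0.1, lr=0.05, max_epochs=500, max_patience=5):
    """Hypothetical helper (not the flowjax API): gradient descent on a training
    split, with early stopping on the loss of a held-out validation split."""
    # Shuffle, then hold out a fraction of the data for validation.
    x = jax.random.permutation(key, x)
    n_val = max(1, int(val_prop * x.shape[0]))
    x_val, x_train = x[:n_val], x[n_val:]

    params = (jnp.array(0.0), jnp.array(0.0))  # (mean, log_std)
    grad_fn = jax.jit(jax.grad(neg_log_likelihood))
    val_loss_fn = jax.jit(neg_log_likelihood)

    best_params, best_val, patience = params, float("inf"), 0
    history = []
    for _ in range(max_epochs):
        # One gradient step on the training split only.
        grads = grad_fn(params, x_train)
        params = tuple(p - lr * g for p, g in zip(params, grads))
        # Monitor the loss on data the optimizer never sees.
        val_loss = float(val_loss_fn(params, x_val))
        history.append(val_loss)
        if val_loss < best_val:
            best_params, best_val, patience = params, val_loss, 0
        else:
            patience += 1
            if patience >= max_patience:
                break  # Validation loss stopped improving: stop to avoid overfitting.
    return best_params, history


# Usage: synthetic data from N(3, 2^2).
key_data, key_split = jax.random.split(jax.random.PRNGKey(0))
x = 3.0 + 2.0 * jax.random.normal(key_data, (1000,))
(mean, log_std), history = fit_with_early_stopping(key_split, x)
```

The sketch also illustrates the caveat in the reply. With a single observation, the validation split would take the only data point and leave nothing to train on, so a held-out loss is not meaningful there.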