danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Cross-validation/shuffling test/train split option in fitting utility functions #146

Closed: joshuawchen closed this 5 months ago

joshuawchen commented 5 months ago

Hi! I have a private fork of flowjax in which I implement cross-validation for the test/train split of the data, so as to not overfit to the held-out data (e.g. for fit_to_data/fit_to_variational_target). It works quite well. Would you be open to me contributing this to flowjax via a pull request? A rough sketch of the splitting idea is below.
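A minimal, hypothetical sketch of the kind of k-fold index splitting being proposed (this is illustrative only, not the code in my fork; the function name `k_fold_indices` and its arguments are made up for the example):

```python
import jax
import jax.numpy as jnp


def k_fold_indices(key, n_samples, k=5):
    """Return a list of (train_idx, val_idx) pairs for k-fold cross-validation.

    Each sample is held out exactly once, so no single validation split is
    reused across the whole fit.
    """
    perm = jax.random.permutation(key, n_samples)  # shuffle sample indices
    folds = jnp.array_split(perm, k)               # k roughly equal folds
    splits = []
    for i in range(k):
        val_idx = folds[i]
        train_idx = jnp.concatenate([f for j, f in enumerate(folds) if j != i])
        splits.append((train_idx, val_idx))
    return splits
```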

danielward27 commented 5 months ago

Hello, thanks for your offer!

Within fit_to_data, we already hold out validation data and use it to avoid overfitting. For variational approaches, cross-validation becomes more nuanced. For some models (e.g. amortised variational inference approaches with multiple observations), we could use predictive performance on held-out data to avoid overfitting and improve generalisation to new points. However, for many models it does not make sense, e.g. for any model fit to a single observation. For this reason, I'd initially lean against trying to add cross-validation into the variational inference training.
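For context, the pattern fit_to_data relies on is roughly the following: hold out a validation fraction, train on the remainder, and keep the parameters with the best validation loss via patience-based early stopping. This is a simplified, generic illustration of that idea, not flowjax's actual implementation; `fit_step`, `val_loss`, and the parameter names are placeholders:

```python
import numpy as np


def train_with_holdout(data, fit_step, val_loss, init_params,
                       val_prop=0.1, max_epochs=100, patience=5):
    """Illustrative train/validation split with early stopping."""
    rng = np.random.default_rng(0)
    perm = rng.permutation(len(data))
    n_val = int(val_prop * len(data))
    val, train = data[perm[:n_val]], data[perm[n_val:]]

    params, best_params = init_params, init_params
    best_loss, bad_epochs = np.inf, 0
    for _ in range(max_epochs):
        params = fit_step(params, train)   # one epoch of training
        loss = val_loss(params, val)       # loss on the held-out split
        if loss < best_loss:
            best_loss, best_params, bad_epochs = loss, params, 0
        else:
            bad_epochs += 1
            if bad_epochs == patience:     # stop once validation loss stalls
                break
    return best_params
```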