sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

code repetition in train methods #922

Open janfb opened 10 months ago

janfb commented 10 months ago

Description:

In the current implementation of the SBI library, the train(...) methods of SNPE, SNRE, and SNLE contain a significant amount of duplicated code. They share common functionality such as building the neural network, resuming training, and running the training and validation loops.

This redundancy makes the codebase harder to maintain and increases the risk of inconsistencies and bugs when one method is updated or enhanced but the others are not. To address this, we propose refactoring these methods by introducing a unified train function in the base class. This common function would handle the shared parts of the training process, while method-specific losses and keyword arguments are passed in as parameters to accommodate the differences between SNPE, SNRE, and SNLE.
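A minimal sketch of the proposed split, written in plain Python so it runs without PyTorch. All names here (`BaseInference`, `_train`, `ToySNPE`, `ToySNLE`, `npe_like_loss`, `nle_like_loss`) are hypothetical and are not sbi's API. The one-weight "network" and the squared-error losses are toy stand-ins for a real density estimator or classifier and its loss. The base class owns network building, resuming, and the training and validation loops; each method only supplies its loss and forwards keyword arguments.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Sequence, Tuple

Pair = Tuple[float, float]  # hypothetical (theta, x) simulation pair
Net = Dict[str, float]      # toy stand-in for a network: a single weight "w"
# Hypothetical shared loss signature: returns (loss, d loss / d w) for one pair.
LossFn = Callable[[Net, Pair], Tuple[float, float]]


class BaseInference(ABC):
    """Hypothetical base class that owns the training logic shared by all methods."""

    def __init__(self) -> None:
        self._net: Net = {}
        self.epoch = 0
        self.best_val_loss = float("inf")

    def _build_network(self) -> Net:
        return {"w": 0.0}

    def _train(
        self,
        train_data: Sequence[Pair],
        val_data: Sequence[Pair],
        loss_fn: LossFn,
        learning_rate: float = 0.01,
        max_num_epochs: int = 200,
        stop_after_epochs: int = 10,
        resume_training: bool = False,
    ) -> Net:
        # Shared step 1: build a fresh network unless resuming a previous run.
        if not resume_training or not self._net:
            self._net = self._build_network()
            self.epoch = 0
            self.best_val_loss = float("inf")
        epochs_without_improvement = 0
        while self.epoch < max_num_epochs and epochs_without_improvement < stop_after_epochs:
            # Shared step 2: training loop (plain SGD, one pair at a time).
            for pair in train_data:
                _, grad = loss_fn(self._net, pair)
                self._net["w"] -= learning_rate * grad
            self.epoch += 1
            # Shared step 3: validation loop with simple early stopping.
            val_loss = sum(loss_fn(self._net, p)[0] for p in val_data) / len(val_data)
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
        return self._net

    @abstractmethod
    def train(self, train_data: Sequence[Pair], val_data: Sequence[Pair], **kwargs: Any) -> Net:
        """Each method only picks its loss and forwards keyword arguments."""


def npe_like_loss(net: Net, pair: Pair) -> Tuple[float, float]:
    # Toy posterior-style loss: predict theta from x.
    theta, x = pair
    residual = net["w"] * x - theta
    return residual**2, 2 * residual * x


def nle_like_loss(net: Net, pair: Pair) -> Tuple[float, float]:
    # Toy likelihood-style loss: predict x from theta.
    theta, x = pair
    residual = net["w"] * theta - x
    return residual**2, 2 * residual * theta


class ToySNPE(BaseInference):
    def train(self, train_data, val_data, **kwargs):
        return self._train(train_data, val_data, npe_like_loss, **kwargs)


class ToySNLE(BaseInference):
    def train(self, train_data, val_data, **kwargs):
        return self._train(train_data, val_data, nle_like_loss, **kwargs)


data = [(0.1 * i, 0.2 * i) for i in range(1, 11)]  # toy data with x = 2 * theta
val = [(0.15, 0.3), (0.55, 1.1), (0.95, 1.9)]
print(ToySNPE().train(data, val, learning_rate=0.05))  # w close to 0.5
print(ToySNLE().train(data, val, learning_rate=0.05))  # w close to 2.0
```

With this split, a fix to resuming or early stopping would live in one place, and adding a method would only require defining its loss. A real implementation would presumably use PyTorch modules, optimizers, and data loaders instead of the toy pieces above.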

Example

SNPE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snpe/snpe_base.py#L340-L379
SNLE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snle/snle_base.py#L214-L244
SNRE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snre/snre_base.py#L228-L260

Checklist

We invite contributors to discuss potential strategies for this refactoring and contribute to its implementation. This effort will enhance the library's maintainability and ensure consistency across different components.

If you find other places where code duplication could be significantly reduced, please create a new issue (e.g., #921).

janfb commented 4 months ago

This will become even more relevant once we have a common dataloader interface and method-agnostic loss functions for all SBI methods. However, I am removing the hackathon label for now, as this will not be done before the release.