Open janfb opened 10 months ago
This will become even more relevant when we have a common dataloader interface and agnostic loss
functions for all SBI methods. But I am removing the hackathon label for now as it will not be done before the release.
Description:
In the current implementation of the SBI library, the `train(...)` methods in SNPE, SNRE, and SNLE exhibit a significant amount of code duplication. These methods share common functionality such as building the neural network, resuming training, and managing the training and validation loops. This redundancy not only makes the codebase harder to maintain but also increases the risk of inconsistencies and bugs during updates or enhancements. To address this, we propose refactoring these methods by introducing a unified `train` function in the base class. This common `train` function would handle the shared aspects of the training process, with the method-specific losses and keyword arguments passed as parameters to accommodate the differences between SNPE, SNRE, and SNLE.

Example
- SNPE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snpe/snpe_base.py#L340-L379
- SNLE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snle/snle_base.py#L214-L244
- SNRE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snre/snre_base.py#L228-L260
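To make the proposal concrete, here is a minimal sketch of the pattern: the base class owns the single training loop and delegates only the loss to subclasses. This is plain Python rather than sbi's actual torch-based implementation, and all class and method names below are invented for illustration, not sbi's real API:

```python
from typing import List


class TrainerBase:
    """Hypothetical base class: one shared train loop for all subclasses."""

    def __init__(self) -> None:
        self.loss_history: List[float] = []

    def _loss(self, batch: float) -> float:
        # Each inference method (SNPE/SNLE/SNRE) supplies its own loss.
        raise NotImplementedError

    def train(self, data: List[float], epochs: int = 2) -> List[float]:
        # Shared loop: iterate over epochs and batches; only the loss
        # computation differs between subclasses.
        for _ in range(epochs):
            for batch in data:
                self.loss_history.append(self._loss(batch))
        return self.loss_history


class SNPELikeTrainer(TrainerBase):
    def _loss(self, batch: float) -> float:
        return batch ** 2  # stand-in for the posterior loss


class SNLELikeTrainer(TrainerBase):
    def _loss(self, batch: float) -> float:
        return abs(batch)  # stand-in for the likelihood loss
```

With this layout, `SNPELikeTrainer().train([1.0, 2.0], epochs=1)` returns `[1.0, 4.0]`: the loop itself lives in one place, so fixes to early stopping, resuming, or validation handling would apply to all methods at once.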
Checklist
- [ ] Identify the shared functionality in the `train` methods of SNPE, SNRE, and SNLE.
- [ ] Implement a common `train` function in the base class that accepts the specific losses and other necessary arguments unique to each method.
- [ ] Refactor the existing `train` methods to utilize this new generic function, passing their specific requirements as arguments.

We invite contributors to discuss potential strategies for this refactoring and contribute to its implementation. This effort will enhance the library's maintainability and ensure consistency across different components.
If you find other locations where we can significantly reduce code duplication, please create a new issue (e.g., #921).