Offload most of the training boilerplate (early stopping, model saving, GPU distribution, the training loop, and so on) into its own class. That way, the science code (i.e., how the computation in each layer is performed, which optimizers to use, and so on) can be modified and run without having to change a lot of unrelated code.
Additionally, it is cumbersome to have to attach the model to each callback before training. If the user forgets, the model's learned weights cannot be saved properly when using the ModelCheckpoint callback (PyTorch will throw an error).
It makes more sense for a trainer class to associate the model with all of the specified callbacks. That removes the need for the user to set the model on each callback manually.