Add a standard training loop for torch models that iterates over tfds datasets. Its API is the same as that of the `train_keras_app` function.
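For readers unfamiliar with mixing the two frameworks, here is a minimal sketch of what a standard PyTorch training loop over tfds-style batches can look like. It is not the function added by this PR: the name `train_torch_sketch` and its parameters are hypothetical, and it assumes `batches` is a re-iterable of `(images, labels)` pairs convertible with `np.asarray` (for example a list of NumPy arrays, or a batched `tf.data.Dataset` loaded with `as_supervised=True`), with images in NHWC layout as tfds yields them.

```python
import numpy as np
import torch
from torch import nn


def train_torch_sketch(model, batches, epochs=1, lr=1e-3, device="cpu"):
    """Hypothetical sketch: standard PyTorch loop over (images, labels) batches.

    Assumes `batches` can be iterated once per epoch and that images are NHWC.
    Returns the mean training loss of each epoch.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history = []
    for _ in range(epochs):
        model.train()
        total_loss, n_samples = 0.0, 0
        for images, labels in batches:
            # tfds yields NHWC images; PyTorch convolutions expect NCHW.
            x = torch.as_tensor(np.asarray(images), dtype=torch.float32)
            x = x.permute(0, 3, 1, 2).to(device)
            y = torch.as_tensor(np.asarray(labels), dtype=torch.long).to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            # Accumulate a sample-weighted sum so the epoch mean is exact.
            total_loss += loss.item() * y.shape[0]
            n_samples += y.shape[0]
        history.append(total_loss / n_samples)
    return history
```

Converting each batch to tensors inside the loop keeps the data pipeline on the TensorFlow side, so the same tfds dataset can feed either a Keras or a torch model.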
Refactor `train_cifar10` by adding an argument parser, so that training can easily use either a Keras or a torch model.
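A parser that switches between backends can be sketched with the standard `argparse` module. The flags `--framework` and `--epochs` below are hypothetical, since the actual options of `train_cifar10` are not listed here, and the dispatch only returns a description instead of launching real training.

```python
import argparse


def build_parser():
    # Hypothetical flags: the real train_cifar10 options may differ.
    parser = argparse.ArgumentParser(
        description="Train a CIFAR-10 model with Keras or PyTorch."
    )
    parser.add_argument(
        "--framework",
        choices=["keras", "torch"],
        default="keras",
        help="Backend used to build and train the model.",
    )
    parser.add_argument("--epochs", type=int, default=10)
    return parser


def main(argv=None):
    args = build_parser().parse_args(argv)
    # Placeholder dispatch: a real script would call the Keras or torch
    # training function here instead of returning a string.
    if args.framework == "torch":
        return f"torch training for {args.epochs} epochs"
    return f"keras training for {args.epochs} epochs"
```

Restricting `--framework` with `choices` makes `argparse` reject unsupported backends before any model or dataset is loaded.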
Note: the function `allow_growth` is called when the `oodeel.models.training_funcs.torch_models` module is imported. It prevents TensorFlow from allocating all of the GPU memory when a tfds dataset is loaded, so that some VRAM remains free to train torch models. Therefore, the `torch_models` module must be imported BEFORE calling `data_handler.load_tfds`.
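The exact body of oodeel's `allow_growth` is not shown here. A common way to get this behavior in TensorFlow 2.x is `tf.config.experimental.set_memory_growth`, and the sketch below assumes that approach; `allow_growth_sketch` is a hypothetical name. Memory growth can only be set before TensorFlow initializes its GPUs (otherwise it raises a `RuntimeError`), which is why the call has to happen before any tfds dataset is loaded.

```python
import tensorflow as tf


def allow_growth_sketch():
    """Hypothetical sketch: let TensorFlow allocate GPU memory on demand.

    Must run before TensorFlow initializes its GPUs (e.g. before loading a
    tfds dataset); afterwards set_memory_growth raises RuntimeError.
    Returns the number of GPUs that were configured.
    """
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    return len(gpus)
```

On a CPU-only machine the list of GPUs is empty and the function simply returns 0. Running such a call at import time of `torch_models` ensures it executes before `data_handler.load_tfds` makes TensorFlow claim the GPU.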
Run training on CIFAR-10: