JeffersonLab / jlab_datascience_core


Model.fit VS Trainer.fit #19

Open ahmedmohammed107 opened 8 months ago

ahmedmohammed107 commented 8 months ago

DIDACT Scenario:

While testing the new DIDACT workflow, Diana registered the model to MLflow with 1 epoch just to make sure the code runs. She then loaded the model from MLflow and changed the number of epochs in the configuration file so she could actually train it. But because the number of training epochs is registered together with the model, she had to manually override model['training_config']['num_epochs'].
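A hypothetical sketch of that workaround (the loader name is a placeholder for however the DIDACT workflow retrieves the registered model; only the quoted config keys come from the scenario above):

```python
# Placeholder for however the workflow loads the registered model from MLflow.
model = load_registered_model("didact_model")

# The registered model carries the training configuration it was logged with
# (1 epoch), so the value has to be overridden by hand before the real run.
model['training_config']['num_epochs'] = 100
```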

Problems:

1) The model's constructor receives a lot of information that has nothing to do with the model itself. If the model is only used for inference, for example, it makes no sense to supply a learning rate in the constructor. That value could instead be passed as an argument to model.fit, but there may be many other training parameters that would have to be passed the same way (see the sketch below).

2) In addition, the fit method is where logging is performed and snapshots are saved. Again, this adds behavior to the Model class that is irrelevant to the model itself.
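A minimal sketch of the pattern these problems describe; the class and config keys below are illustrative, not the actual jlab_datascience_core API:

```python
# Hypothetical current pattern: the model constructor carries training-only
# settings, and fit() owns the training loop, logging, and snapshots.
class Model:
    def __init__(self, config):
        self.architecture = config["architecture"]     # genuinely model-related
        self.learning_rate = config["learning_rate"]   # training-only
        self.num_epochs = config["num_epochs"]         # training-only

    def fit(self, train_data, val_data):
        # training loop, logging, and snapshot saving all live here
        ...

# An inference-only caller still has to supply learning_rate and num_epochs,
# even though they are never used.
model = Model({"architecture": [64, 64], "learning_rate": 1e-3, "num_epochs": 1})
```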

Proposed Solution ===> trainer.fit instead of model.fit:

Separating the model from the training logic is a widely adopted pattern in PyTorch. The model would only have init and forward (or call) methods. The Trainer class would do all the heavy lifting: it takes the model, the training/validation data, and all training configs (learning rate, optimizer config, scheduler config, early-stopping config). It also handles logging and knows where to save snapshots. It can serve as the interface to MLflow by saving: 1) parameters before training starts, and 2) metrics and artifacts after training is complete.
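A minimal sketch of the proposed separation, assuming a PyTorch-style model and a regression loss; the class and argument names are illustrative, not an existing API, and the MLflow calls are indicated only as comments:

```python
import os
import torch

class Trainer:
    """Owns the training loop, logging, and snapshots; the model stays minimal."""

    def __init__(self, learning_rate=1e-3, num_epochs=10, snapshot_dir="snapshots"):
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.snapshot_dir = snapshot_dir

    def fit(self, model, train_loader, val_loader=None):
        os.makedirs(self.snapshot_dir, exist_ok=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        loss_fn = torch.nn.MSELoss()
        # 1) parameters could be logged here, before training starts,
        #    e.g. mlflow.log_params({"learning_rate": ..., "num_epochs": ...})
        for epoch in range(self.num_epochs):
            model.train()
            for x, y in train_loader:
                optimizer.zero_grad()
                loss = loss_fn(model(x), y)
                loss.backward()
                optimizer.step()
            # validation, early stopping, and logging would live here,
            # not in the model class
            torch.save(model.state_dict(),
                       os.path.join(self.snapshot_dir, f"epoch_{epoch}.pt"))
        # 2) metrics and artifacts could be logged here, after training completes,
        #    e.g. mlflow.log_metrics(...) / mlflow.log_artifacts(...)
        return model
```

With this split, changing the number of epochs in the scenario above would only touch the Trainer configuration; nothing registered with the model would need to be overridden.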

schr476 commented 7 months ago

@ahmedmohammed107, please focus on the parser first.