Model weights can be initialized from an arbitrary file to continue training (possibly with new parameters). Accordingly, there are two paths associated with a model:
model_load_path – path to the file used to initialize the model's weights. Pass this path to the training script tools/train.py via the --init_weights parameter. If it is not specified, CakeChatModel constructs the path based on config.py and tries to initialize its weights from that file.
model_save_path – path to the file where the model's weights are saved. This path is constructed by CakeChatModel.
NN_MODEL_PREFIX – the model's name always includes this prefix, which can be general (for example, 'cakechat') or specific to a given experiment.
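A minimal sketch of how the two paths could be resolved. Only the --init_weights flag, NN_MODEL_PREFIX and the "fall back to a config-derived path" behaviour come from these notes. MODELS_DIR, the <prefix>_<params_str>.h5 naming scheme and the helper names are hypothetical, and the sketch assumes the fallback load path is built the same way as the save path.

```python
import argparse
import os
from typing import Optional, Tuple

# Hypothetical stand-ins for values that would be read from config.py.
NN_MODEL_PREFIX = "cakechat"
MODELS_DIR = os.path.join("results", "nn_models")


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # --init_weights is the flag named in the notes; the None default is an assumption.
    parser.add_argument("--init_weights", default=None)
    return parser.parse_args(argv)


def build_model_path(params_str: str, prefix: str = NN_MODEL_PREFIX) -> str:
    # Hypothetical naming scheme: <prefix>_<params_str>.h5 inside MODELS_DIR.
    return os.path.join(MODELS_DIR, "{}_{}.h5".format(prefix, params_str))


def resolve_paths(params_str: str, init_weights: Optional[str] = None) -> Tuple[str, str]:
    """Return (model_load_path, model_save_path)."""
    model_save_path = build_model_path(params_str)
    # Fall back to the constructed path when --init_weights is not given.
    model_load_path = init_weights if init_weights is not None else model_save_path
    return model_load_path, model_save_path
```

For example, resolve_paths("gru_512", parse_args(["--init_weights", "pretrained.h5"]).init_weights) loads from pretrained.h5 but still saves to the constructed path, which is what lets training continue from an arbitrary file without overwriting it.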
Trained models are accessed via get_trained_model() and cached by the @cached decorator.
A separate module for data types: cakechat/utils/data_types.py
The following entities are incorporated into CakeChatModel:
is_reverse_model flag
params_str – a string with the main model parameters
save_model() and delete_model() functions
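A heavily simplified, hypothetical class showing how these entities could fit together. Only is_reverse_model, params_str, save_model() and delete_model() come from these notes; the class name CakeChatModelSketch, the constructor arguments, the params_str format, the file naming and the bytes-based save are assumptions made to keep the sketch runnable.

```python
import os


class CakeChatModelSketch:
    """Hypothetical, minimal stand-in for CakeChatModel."""

    def __init__(self, hidden_dim: int, num_layers: int, is_reverse_model: bool = False,
                 prefix: str = "cakechat", models_dir: str = "."):
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.is_reverse_model = is_reverse_model
        self._prefix = prefix
        self._models_dir = models_dir

    @property
    def params_str(self) -> str:
        # Encode the main hyperparameters into a filename-friendly string.
        direction = "reverse" if self.is_reverse_model else "forward"
        return "{}_hd{}_l{}".format(direction, self.hidden_dim, self.num_layers)

    @property
    def model_save_path(self) -> str:
        # The save path is derived from the prefix and params_str, not passed in.
        return os.path.join(self._models_dir, "{}_{}.weights".format(self._prefix, self.params_str))

    def save_model(self, weights: bytes) -> None:
        # A real model would serialize its network weights here.
        with open(self.model_save_path, "wb") as f:
            f.write(weights)

    def delete_model(self) -> None:
        # Remove the saved weights file if it exists.
        if os.path.exists(self.model_save_path):
            os.remove(self.model_save_path)
```

Deriving the save path from params_str means two models with different hyperparameters, or a forward and a reverse model, never overwrite each other's weights.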
log_predictions() can generate predictions for several PREDICTION_MODES at once, not just for one.
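A hypothetical sketch of generating predictions for several modes in a single call. The mode names and the predict callback are assumptions for illustration, not CakeChat's actual PREDICTION_MODES values or API.

```python
from typing import Callable, Dict, Iterable, List

# Hypothetical mode names; the real values would come from the project's config.
PREDICTION_MODES = ("argmax", "sampling")


def log_predictions(contexts: List[str],
                    predict: Callable[[str, str], str],
                    modes: Iterable[str] = PREDICTION_MODES) -> Dict[str, List[str]]:
    """Generate and print one response per context for every requested mode."""
    predictions = {}
    for mode in modes:
        predictions[mode] = [predict(context, mode) for context in contexts]
        for context, response in zip(contexts, predictions[mode]):
            print("[{}] {} -> {}".format(mode, context, response))
    return predictions
```

Returning a dict keyed by mode keeps the results for different modes side by side, so they can be compared for the same contexts.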
All datasets (train, train_subset, context_free_val, context_sensitive_val, context_sensitive_val_subset) are accessed via get_datasets() and passed as a namedtuple, DatasetsCollection.
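A minimal sketch of the DatasetsCollection idea using collections.namedtuple. The field names follow the dataset list above; the get_datasets() body is a hypothetical toy split of a list of lines, used only to make the example runnable.

```python
from collections import namedtuple

# Field names mirror the datasets listed above; the element types are assumptions.
DatasetsCollection = namedtuple(
    "DatasetsCollection",
    ["train", "train_subset", "context_free_val",
     "context_sensitive_val", "context_sensitive_val_subset"])


def get_datasets(lines, subset_size=2, val_size=2):
    """Hypothetical loader: a toy split of a list of lines into the five datasets."""
    train, val = lines[:-val_size], lines[-val_size:]
    return DatasetsCollection(
        train=train,
        train_subset=train[:subset_size],
        context_free_val=val,
        context_sensitive_val=val,
        context_sensitive_val_subset=val[:1])
```

Because a namedtuple is immutable and its fields are named, callers write datasets.train_subset instead of unpacking a positional tuple whose order could silently change.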
All statistics about the training process are collected and passed as a namedtuple, TrainStats.
Logging logic is implemented in the dedicated function _log_train_info_for_one_batch().
Analytics logic (metrics calculation) is implemented in the dedicated function _analyse_model_performance().
The constant MAX_VAL_LINES_NUM sets the maximum number of validation-set lines used for metrics calculation, which is useful for testing and debugging.
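A minimal sketch of capping the validation set with such a constant. The value 3, the load_val_lines() helper and the convention that None means "use every line" are hypothetical.

```python
import itertools
from typing import Iterable, List, Optional

# Hypothetical value: keep it small while testing or debugging.
MAX_VAL_LINES_NUM: Optional[int] = 3


def load_val_lines(lines: Iterable[str], max_lines: Optional[int] = MAX_VAL_LINES_NUM) -> List[str]:
    """Return at most max_lines validation lines; None means no limit."""
    # islice stops early without reading the rest of the input.
    return list(itertools.islice(lines, max_lines))
```

Using islice rather than slicing a list means a large validation file opened as an iterator is never read past the cap.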
Documentation is updated for various functions.
Some typos are fixed.