neulab / xnmt

eXtensible Neural Machine Translation

Multi-task learning and checkpoint saving #395

Open mbollmann opened 6 years ago

mbollmann commented 6 years ago

I am trying to train a model with a relatively large number of auxiliary tasks (~30). Training the network itself runs fine, but the setup is ultimately impractical because of excessive checkpoint saving.

  1. When using a multi-task training regimen, the save function (save_fct) is potentially called once for each task, even though it is not task-dependent.

    For example: https://github.com/neulab/xnmt/blob/master/xnmt/training_regimen.py#L237

    If I see this correctly, for a model consisting of n training tasks, the identical model state is saved up to n times in a row, wasting computation time (one way to deduplicate these writes is sketched after this list).

  2. In a multi-task training regimen, the model saving seems to be triggered whenever any of the tasks completes an epoch.

    This is because TrainingTask decides that saving is always needed when there are no dev tasks: https://github.com/neulab/xnmt/blob/master/xnmt/training_task.py#L339

    However, in an MTL scenario, "no dev tasks" can simply mean that I'm not interested in evaluating this particular training task, so it should never by itself trigger checkpoint saving. I don't see any way to achieve this behavior right now.
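Regarding point 1, a minimal sketch of the deduplication idea, assuming nothing about xnmt's internals beyond the fact that the regimen holds a single task-independent save function: wrap save_fct so that repeated requests within the same checkpoint round collapse into one on-disk write. The names below (DedupSaveFct, round_id) are hypothetical and not part of xnmt.

```python
# Hypothetical helper, not part of xnmt: ensures that the (task-independent)
# model state is written at most once per checkpoint round, even if several
# training tasks request saving in a row.
class DedupSaveFct:
    def __init__(self, save_fct):
        self._save_fct = save_fct        # the regimen's real save function
        self._last_saved_round = None    # id of the last round we wrote

    def __call__(self, round_id):
        if round_id == self._last_saved_round:
            return                       # identical state is already on disk
        self._save_fct()
        self._last_saved_round = round_id


# Usage sketch inside a multi-task training loop: round_id could be the
# global minibatch counter, or any value that changes between checkpoints.
#
#   save_fct = DedupSaveFct(save_fct)
#   for task in tasks:
#       if task_requests_saving:
#           save_fct(global_step)   # only the first request actually writes
```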

mbollmann commented 6 years ago

Idea: No. 2 could probably be achieved by defining a new AuxiliaryTrainingTask that ignores checkpoints, which could then be used whenever this particular behavior is desired.
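A rough sketch of what that could look like, with heavy caveats: the concrete base class (SimpleTrainingTask) and the exact signature and return value of checkpoint() are assumptions based on the code linked above, and xnmt's YAML-serialization boilerplate is omitted.

```python
# Hypothetical sketch, assuming the concrete task class is SimpleTrainingTask
# and that checkpoint(control_learning_schedule) returns a flag telling the
# regimen whether the model needs saving; the real names/signatures in xnmt
# may differ, and serialization boilerplate is left out.
from xnmt.training_task import SimpleTrainingTask

class AuxiliaryTrainingTask(SimpleTrainingTask):
    """A training task whose checkpoints never trigger model saving."""

    def checkpoint(self, control_learning_schedule=False):
        # Run the usual checkpoint bookkeeping (dev eval, learning-rate
        # control, ...) but always report that no saving is needed.
        super().checkpoint(control_learning_schedule=control_learning_schedule)
        return False
```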

msperber commented 6 years ago

Yeah, that's true; the case of no dev tasks is currently not handled ideally. I would prefer to have the training regimen be in charge of when stuff gets saved; the training tasks should only give the regimen a hint when a new best score has been achieved. Probably, this would amount to the following:

If deviations from this are desired, they could be achieved by configuring the training regimen accordingly, although it seems to me that this default behavior would be reasonable in most cases.

Necessary changes might include dividing training_task.checkpoint(control_learning_schedule) into two methods, e.g. training_task.checkpoint() and training_task.control_learning_schedule(), which is probably the cleaner solution anyway.
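To make the proposal concrete, here is a hedged sketch of the split interface and of how the regimen would take over saving. The method names mirror the comment above; everything else is an illustrative stand-in, not xnmt's actual code.

```python
# Illustrative stand-ins, not xnmt's actual classes: each task only reports
# whether a new best dev score was reached; the regimen decides when the
# single shared model checkpoint gets written.
class TrainingTask:
    def checkpoint(self) -> bool:
        """Run the dev checkpoint; return True iff a new best score was achieved."""
        raise NotImplementedError

    def control_learning_schedule(self) -> None:
        """Adjust learning rate / early stopping based on the latest dev results."""
        raise NotImplementedError


class MultiTaskTrainingRegimen:
    def __init__(self, tasks, save_fct):
        self.tasks = tasks
        self.save_fct = save_fct

    def run_checkpoints(self):
        # Evaluate every task first (no short-circuiting), then save at most once.
        hints = [task.checkpoint() for task in self.tasks]
        if any(hints):
            self.save_fct()
        for task in self.tasks:
            task.control_learning_schedule()
```

Under this scheme a task without dev sets would simply return False from checkpoint(), which would also resolve point 2 of the original report.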