mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

Control over batch-norm running_mean/var buffers #767

Open adefazio opened 2 months ago

adefazio commented 2 months ago


Following up on the request in the recent working group meeting regarding future improvements to the challenge: it would be extremely useful to have control over the running_mean/var buffers of batch-norm layers. Currently, if different iterates are used for evaluation than for training (e.g. when EMA or Schedule-Free averaging is used), the running_mean/var values will be incorrect, since they are accumulated over the training iterates.
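For concreteness, here is a minimal PyTorch sketch of the re-estimation this forces on us, written against a toy model rather than the benchmark API: when evaluation uses averaged weights, the running statistics have to be recomputed under those weights instead of the training iterates.

```python
import copy
import torch
import torch.nn as nn

# Toy model standing in for a workload that contains batch norm.
model = nn.Sequential(nn.Linear(32, 64), nn.BatchNorm1d(64), nn.ReLU())

# Stand-in for the averaged iterate (what EMA / Schedule-Free would evaluate).
avg_model = copy.deepcopy(model)

# The copied running_mean/var were accumulated under the *training* weights,
# so they must be re-estimated under the averaged weights before evaluation.
for m in avg_model.modules():
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.reset_running_stats()

avg_model.train()  # BN buffers are only updated in train mode
with torch.no_grad():
    for _ in range(10):  # a few forward passes over (here: synthetic) data
        avg_model(torch.randn(16, 32))
avg_model.eval()
```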

This, together with the eval() support requested in #758, would make it much easier to implement averaging approaches.

In terms of control, it would be useful to be able to turn the updating of the running mean/var during forward passes on and off, and to access their values directly. Currently there is an update_batch_norm switch that calls update_batch_norm_fn in pytorch_utils, but it doesn't allow us to update the batch-norm stats while in eval mode (eval mode changes the behavior of dropout, so we want to be in eval mode when updating BN statistics right before a model evaluation).
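A minimal sketch of the behavior being asked for, written against plain PyTorch modules rather than the benchmark API: keep the model in eval mode so dropout is disabled, but switch only the batch-norm modules back to train mode so their buffers still update on forward passes.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(32, 64),
    nn.BatchNorm1d(64),
    nn.Dropout(p=0.1),
    nn.ReLU(),
)

model.eval()  # puts every module, including dropout, into eval mode
for m in model.modules():
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.train()  # re-enable running-stat updates for BN only

with torch.no_grad():
    model(torch.randn(16, 32))  # BN buffers update, dropout stays disabled
```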

Also, having the batch-norm running-mean/var provided directly in the API would give a model-agnostic way to access them; currently we would need to loop over all modules and check whether they are PyTorch batch-norm layers or the custom ConformerBatchNorm and DeepspeechBatchNorm layers.
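Roughly, the module-walking code every submission would otherwise have to carry looks like the following (the attribute checks for the custom layers are assumptions; those classes may store their statistics differently):

```python
import torch.nn as nn

def collect_bn_buffers(model: nn.Module):
    """Collect (running_mean, running_var) pairs from all batch-norm-like layers."""
    buffers = []
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            buffers.append((module.running_mean, module.running_var))
        # Workload-specific layers have to be matched by name, e.g.:
        elif type(module).__name__ in ("ConformerBatchNorm", "DeepspeechBatchNorm"):
            buffers.append((getattr(module, "running_mean", None),
                            getattr(module, "running_var", None)))
    return buffers
```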

A third point is rules clarity around batch-norm layers. Are we freely allowed to change the batch-norm momentum during training (which would allow us to freeze the running stats, reset them, and otherwise change the speed at which they are updated), as well as the running-mean/var buffers themselves?
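To illustrate the kind of manipulation this question is about (a sketch on a bare PyTorch layer, assuming the API eventually exposes the modules or their buffers):

```python
import torch.nn as nn

bn = nn.BatchNorm1d(64)

bn.momentum = 0.0         # freezes running_mean/var: new batches contribute nothing
bn.reset_running_stats()  # resets mean to 0, var to 1, and the batch counter
bn.momentum = 0.1         # restores the default update speed
```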