mlcommons / algorithmic-efficiency

MLCommons Algorithmic Efficiency is a benchmark and competition measuring neural network training speedups due to algorithmic improvements in both training algorithms and models.
https://mlcommons.org/en/groups/research-algorithms/
Apache License 2.0

BatchNorm fixes for JAX and PyTorch workloads #798

Closed: priyakasimbeg closed this 3 weeks ago

priyakasimbeg commented 1 month ago

Fixes BatchNorm behavior in the JAX and PyTorch workloads. The main change decouples updating the batch norm statistics from using the running statistics.
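To illustrate what this decoupling means, here is a minimal NumPy sketch. It is not the code from this PR. The function `batch_norm` is hypothetical, its flag `update_batch_norm` is named after the `workload.model_fn` parameter discussed below, and `use_running_average` borrows the name of Flax's `nn.BatchNorm` argument. The momentum convention (weight on the old running value) is an assumption for this sketch. Because the two flags are independent, a forward pass can, for example, normalize with batch statistics while leaving the stored running statistics untouched.

```python
import numpy as np


def batch_norm(x, state, use_running_average, update_batch_norm,
               momentum=0.9, eps=1e-5):
    """Hypothetical functional BatchNorm over axis 0 (the batch axis).

    The two flags are independent of each other:
      use_running_average: normalize with the stored running statistics
        instead of the statistics of the current batch.
      update_batch_norm: return updated running statistics instead of
        returning the old state unchanged.
    """
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)  # biased variance

    # Choice 1: which statistics are USED to normalize.
    if use_running_average:
        mean, var = state["mean"], state["var"]
    else:
        mean, var = batch_mean, batch_var
    y = (x - mean) / np.sqrt(var + eps)

    # Choice 2: whether the running statistics are UPDATED.
    if update_batch_norm:
        # Exponential moving average; momentum weights the old value
        # (an assumption of this sketch).
        new_state = {
            "mean": momentum * state["mean"] + (1 - momentum) * batch_mean,
            "var": momentum * state["var"] + (1 - momentum) * batch_var,
        }
    else:
        new_state = state
    return y, new_state


# Usage: normalize with batch statistics without touching the running stats.
state = {"mean": np.zeros(2), "var": np.ones(2)}
x = np.array([[1.0, 2.0], [3.0, 6.0]])
y, same_state = batch_norm(x, state, use_running_average=False,
                           update_batch_norm=False)
```

With both flags exposed, the four combinations cover train-style behavior (batch statistics, update), eval-style behavior (running statistics, no update), and the two mixed cases that a single train/eval switch does not express.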

The PyTorch changes are taken from @adefazio's PR https://github.com/mlcommons/algorithmic-efficiency/pull/783.

From pull/783: There are some subtle issues with how BatchNorm is handled in the PyTorch version of the code. Currently, workload.model_fn has an update_batch_norm parameter, which in theory should allow the submission to control whether the batch-norm statistics are updated during a forward pass. The issues are the following:

github-actions[bot] commented 1 month ago

MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅