What does this PR do?
This PR updates the computation of the total number of parameters in a (sharded) model.
Previously, when the model was sharded, the parameter count only reflected the number of parameters in a single shard.
General Changes
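The per-shard parameter counts are now summed across all ranks in a collective reduction, so the reported total covers the full model rather than a single shard.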
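For illustration only, here is a minimal pure-Python sketch of the idea, not the code in this PR. It simulates the ranks in a single process; the helper names (`local_num_params`, `all_reduce_sum`) and the shard layout are hypothetical. In a real distributed job the sum would be done by a collective, for example `torch.distributed.all_reduce` with `ReduceOp.SUM` if the project is built on PyTorch.

```python
from math import prod

# Hypothetical shard layout: for each rank, the shapes of the parameter
# tensors that the rank holds locally.
Shape = tuple[int, ...]


def local_num_params(shard_shapes: list[Shape]) -> int:
    """Number of parameters held by a single rank (the old behaviour)."""
    return sum(prod(shape) for shape in shard_shapes)


def all_reduce_sum(per_rank_values: list[int]) -> list[int]:
    """Simulate a sum all-reduce: every rank ends up with the global sum."""
    total = sum(per_rank_values)
    return [total] * len(per_rank_values)


# Two simulated ranks, each holding half of a 4x6 weight and half of a
# bias vector of length 6.
shards = [
    [(2, 6), (3,)],
    [(2, 6), (3,)],
]

local_counts = [local_num_params(s) for s in shards]  # per-shard counts
total_counts = all_reduce_sum(local_counts)           # full-model count on every rank

print(local_counts)
print(total_counts)
```

Before the change, each rank would report only its entry in `local_counts`; after the reduction, every rank sees the same full-model total.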
Breaking Changes
None
Checklist before submitting final PR
[x] My PR is minimal and addresses one issue in isolation
[x] I have merged the latest version of the target branch into this feature branch
[x] I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
[x] I have run a sample config for model training
[x] I have checked that all tests run through (python tests/tests.py)