keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Add loss scaling technique to `BaseOptimizer` #842

Closed james77777778 closed 1 year ago

james77777778 commented 1 year ago

Related to #571

I'm not sure whether anyone is already addressing this issue, or whether it should be handled externally instead.

I'm simply trying to implement a backend-agnostic solution. Currently, only the tf trainer can automatically enable loss scaling when using mixed_float16 and mixed_bfloat16. It should be straightforward to extend this to the other backends.

In the colab, the training works well. However, it will fail if the loss scaling technique is not applied. Colab: https://colab.research.google.com/drive/1T8DpRXu7CI52-67Poq5kKFvEM1P3UwT-?usp=sharing

cc @mattdangerw

EDITED: This PR should now work with the tf and torch backends. The jax backend is not supported yet.
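For readers unfamiliar with the technique, the general idea of dynamic loss scaling can be sketched roughly as below. This is only an illustration in NumPy, not the code in this PR: the class `DynamicLossScaler`, its method names, and the default values are all hypothetical and chosen for the example.

```python
import numpy as np


class DynamicLossScaler:
    """Hypothetical helper illustrating dynamic loss scaling (not this PR's API)."""

    def __init__(self, initial_scale=2.0**15, growth_interval=2000):
        self.scale = initial_scale
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps with finite gradients

    def scale_loss(self, loss):
        # Multiply the loss so that small gradients stay representable in float16.
        return loss * self.scale

    def unscale_and_check(self, scaled_grads):
        """Return (grads, finite); grads is None when any gradient overflowed."""
        grads = [np.asarray(g, dtype=np.float32) / self.scale for g in scaled_grads]
        finite = all(bool(np.all(np.isfinite(g))) for g in grads)
        if finite:
            self.good_steps += 1
            if self.good_steps >= self.growth_interval:
                # Enough stable steps in a row: try a larger scale.
                self.scale *= 2.0
                self.good_steps = 0
            return grads, True
        # Overflow: the caller should skip this update; retry with a smaller scale.
        self.scale /= 2.0
        self.good_steps = 0
        return None, False
```

In a training step, the loss would be passed through `scale_loss` before the backward pass, and the optimizer update would be applied only when `unscale_and_check` reports finite gradients.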

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 56.00% and project coverage change: +5.03% :tada:

Comparison is base (49e5b06) 70.90% compared to head (609d711) 75.94%. Report is 3 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #842      +/-   ##
==========================================
+ Coverage   70.90%   75.94%   +5.03%
==========================================
  Files         344      328      -16
  Lines       33300    31166    -2134
  Branches     6409     6063     -346
==========================================
+ Hits        23612    23669      +57
+ Misses       8093     5894    -2199
- Partials     1595     1603       +8
```

| [Flag](https://app.codecov.io/gh/keras-team/keras-core/pull/842/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras_core](https://app.codecov.io/gh/keras-team/keras-core/pull/842/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `75.85% <56.00%> (?)` | |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.

| [Files Changed](https://app.codecov.io/gh/keras-team/keras-core/pull/842?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras\_core/optimizers/base\_optimizer.py](https://app.codecov.io/gh/keras-team/keras-core/pull/842?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9vcHRpbWl6ZXJzL2Jhc2Vfb3B0aW1pemVyLnB5) | `69.78% <50.81%> (-4.42%)` | :arrow_down: |
| [keras\_core/trainers/trainer.py](https://app.codecov.io/gh/keras-team/keras-core/pull/842?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS90cmFpbmVycy90cmFpbmVyLnB5) | `84.07% <72.72%> (-0.66%)` | :arrow_down: |
| [keras\_core/backend/tensorflow/trainer.py](https://app.codecov.io/gh/keras-team/keras-core/pull/842?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL3RlbnNvcmZsb3cvdHJhaW5lci5weQ==) | `78.72% <100.00%> (+0.05%)` | :arrow_up: |
| [keras\_core/backend/torch/trainer.py](https://app.codecov.io/gh/keras-team/keras-core/pull/842?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL3RvcmNoL3RyYWluZXIucHk=) | `89.95% <100.00%> (+0.04%)` | :arrow_up: |

... and [28 files with indirect coverage changes](https://app.codecov.io/gh/keras-team/keras-core/pull/842/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team)


fchollet commented 1 year ago

Thanks for the PR! @mattdangerw was working on this, so I'll let him review.

james77777778 commented 1 year ago

> Thanks for the PR! @mattdangerw was working on this, so I'll let him review.

Sure.

Here are some additional notes:

Loss scaling is crucial for mixed precision even in a small model. You can observe the difference when loss scaling is applied versus when it is not.

MNIST classification training with mixed_float16:

with loss scaling: [image: with_loss_scaling]

without loss scaling: [image: without_loss_scaling]

Updated colab: https://colab.research.google.com/drive/1T8DpRXu7CI52-67Poq5kKFvEM1P3UwT-?usp=sharing
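To make the underflow problem concrete, here is a small NumPy sketch, separate from the colab above. The function name `roundtrip_fp16` and the example values are illustrative assumptions, not taken from the PR. It casts a tiny gradient to float16 with and without scaling, then unscales in float32.

```python
import numpy as np


def roundtrip_fp16(grad, loss_scale=1.0):
    """Cast `grad * loss_scale` to float16, then unscale the result in float32.

    Hypothetical helper for illustration only.
    """
    scaled = np.float32(grad) * np.float32(loss_scale)
    as_fp16 = np.float16(scaled)  # precision is lost at this cast
    return np.float32(as_fp16) / np.float32(loss_scale)


# 1e-8 is below float16's smallest subnormal (about 6e-8), so it flushes to zero.
without_scaling = roundtrip_fp16(1e-8)

# Scaled by 2**15, the value becomes about 3.3e-4, which float16 can represent.
with_scaling = roundtrip_fp16(1e-8, loss_scale=2.0**15)
```

Without scaling, the gradient is lost entirely; with scaling, it survives the float16 cast and is recovered approximately after dividing by the scale.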

mattdangerw commented 1 year ago

Ah shoot! Yes, I started working on this last week, and I think I have it working on all backends.

On the discussion thread, we decided to try adding this as a separate optimizer for compatibility with tf.keras. If we end up going with that separate-optimizer approach, we may need to close this PR. See https://github.com/keras-team/keras-core/pull/851

mattdangerw commented 1 year ago

Also, one side note: I believe bfloat16 does not actually need loss scaling, as it has the same number of exponent bits as float32. Loss scaling is only needed for mixed_float16!

https://www.tensorflow.org/guide/mixed_precision#loss_scaling
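The exponent-bit point can be checked with a rough NumPy sketch. NumPy has no native bfloat16, so the helper below (a hypothetical `to_bfloat16_emulated`) approximates it by keeping only the top 16 bits of a float32, which preserves the 8 exponent bits but truncates rather than rounds the mantissa. Treat it as an illustration of dynamic range only.

```python
import numpy as np


def to_bfloat16_emulated(x):
    """Emulate bfloat16 by zeroing the low 16 bits of a float32 (truncation).

    Hypothetical helper; real bfloat16 conversions usually round instead.
    """
    arr = np.atleast_1d(np.asarray(x, dtype=np.float32))
    bits = arr.view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


tiny = 1e-30  # far below float16's range, well within float32's

in_float16 = np.float16(tiny)                  # 5 exponent bits: underflows to 0
in_bfloat16 = to_bfloat16_emulated([tiny])[0]  # 8 exponent bits: stays nonzero
```

The value that vanishes in float16 is still representable, with reduced precision, in the emulated bfloat16, which is why loss scaling targets float16 specifically.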

james77777778 commented 1 year ago

Thanks @mattdangerw

I was attempting to train a large model that is hard to fit into 24GB of memory with float32, so I implemented my own solution. It's good to know that we will have an official solution for loss scaling.

This PR should be closed.