keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0
1.27k stars, 115 forks

Add a loss scale optimizer #851

Closed: mattdangerw closed this 12 months ago

mattdangerw commented 1 year ago

This is the biggest missing piece for feature parity with tf.keras when running mixed precision training.

Fixes #571
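To make the motivation concrete: float16 cannot represent very small gradient values, which round to zero. The minimal sketch below emulates float16 rounding with Python's struct module (format 'e') instead of a real backend, and assumes a loss scale of 2**15 purely for illustration. Scaling the loss by a factor scales every gradient by the same factor (chain rule), so the sketch multiplies the gradient directly: scaled before the float16 cast and unscaled in full precision afterwards, a gradient that would otherwise vanish survives.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def cast_gradient(grad: float, loss_scale: float = 1.0) -> float:
    """Emulate a float16 backward pass: the (scaled) gradient is cast to
    float16, then unscaled in full precision before being applied."""
    return to_float16(grad * loss_scale) / loss_scale


tiny_grad = 1e-8  # below half of float16's smallest subnormal (~5.96e-8)

unscaled = cast_gradient(tiny_grad)                    # underflows to 0.0
scaled = cast_gradient(tiny_grad, loss_scale=2.0**15)  # survives the cast

print(unscaled)  # 0.0
print(scaled)    # close to 1e-8
```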

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 86.30% and project coverage change: +0.09% :tada:

Comparison is base (ab45558) 75.99% compared to head (fd58cfb) 76.09%. Report is 6 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #851      +/-   ##
==========================================
+ Coverage   75.99%   76.09%   +0.09%
==========================================
  Files         328      329       +1
  Lines       31099    31269     +170
  Branches     6051     6083      +32
==========================================
+ Hits        23635    23793     +158
- Misses       5866     5874       +8
- Partials     1598     1602       +4
```

| Flag | Coverage Δ | |
|---|---|---|
| keras_core | `75.99% <86.30%> (+0.08%)` | :arrow_up: |

Flags with carried forward coverage won't be shown.

| Files Changed | Coverage Δ | |
|---|---|---|
| `keras_core/backend/tensorflow/optimizer.py` | `90.56% <ø> (-0.35%)` | :arrow_down: |
| `keras_core/callbacks/tensorboard.py` | `83.84% <ø> (+0.11%)` | :arrow_up: |
| `keras_core/ops/core.py` | `74.09% <0.00%> (-4.05%)` | :arrow_down: |
| `keras_core/backend/torch/trainer.py` | `89.56% <50.00%> (-0.35%)` | :arrow_down: |
| `keras_core/backend/tensorflow/trainer.py` | `78.53% <66.66%> (-0.14%)` | :arrow_down: |
| `keras_core/optimizers/base_optimizer.py` | `74.76% <90.00%> (+0.56%)` | :arrow_up: |
| `keras_core/optimizers/loss_scale_optimizer.py` | `94.28% <94.28%> (ø)` | |
| `keras_core/backend/jax/numpy.py` | `97.69% <100.00%> (+0.01%)` | :arrow_up: |
| `keras_core/backend/jax/trainer.py` | `96.08% <100.00%> (+0.11%)` | :arrow_up: |
| `keras_core/optimizers/__init__.py` | `92.10% <100.00%> (+0.21%)` | :arrow_up: |
| ... and 1 more | | |

... and 5 files with indirect coverage changes


mattdangerw commented 1 year ago

A few points of awkwardness/discussion:

  1. ops.cond needs to be stateless for JAX. The autoscaling logic has two cond branches, and both of them update variables. To handle this I overrode stateless_apply separately and had to use a lot of StatelessScopes. This feels very verbose and awkward, but I wasn't able to think of a great way around it. Suggestions welcome!
  2. We want learning_rate to proxy the inner optimizer's learning rate, so I overrode the learning_rate property to do that. But the base optimizer still has to be created with a learning rate, so I ended up passing a zero-valued variable that is never used. This feels awkward and a bit confusing.
  3. Because the loss scale optimizer needs a variable to scale the loss, and we don't sync JAX state except on epoch boundaries, I added both a scale_loss and a stateless_scale_loss. This is probably fine; it just feels like a lot of code.
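The three points above can be illustrated with a toy, framework-free sketch. Everything in it is hypothetical: ScaleState, ToySGD, ToyLossScaleOptimizer and the apply/stateless_apply split are stand-ins, not keras-core's actual classes or signatures, and the policy (double the scale after growth_steps finite steps, halve it and skip the update on overflow) is an assumed dynamic scaling scheme. It shows that when state is immutable and passed in explicitly, both conditional branches simply return a new state (what a stateless cond under JAX requires), the stateful methods become thin wrappers over the stateless ones, and learning_rate can be proxied with a property. A plain Python if stands in for ops.cond.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class ScaleState:
    """Immutable loss-scale state: updates return a new value, never mutate."""
    loss_scale: float
    good_steps: int


class ToySGD:
    """Hypothetical inner optimizer: plain SGD on a list of floats."""

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate

    def apply(self, params: List[float], grads: List[float]) -> List[float]:
        return [p - self.learning_rate * g for p, g in zip(params, grads)]


class ToyLossScaleOptimizer:
    """Hypothetical dynamic loss scale wrapper (not the keras-core API)."""

    def __init__(self, inner: ToySGD, initial_scale: float = 2.0**15,
                 growth_steps: int = 2000):
        self.inner = inner
        self.growth_steps = growth_steps
        self.state = ScaleState(initial_scale, 0)

    # Point 2: proxy the inner optimizer's learning rate.
    @property
    def learning_rate(self) -> float:
        return self.inner.learning_rate

    @learning_rate.setter
    def learning_rate(self, value: float) -> None:
        self.inner.learning_rate = value

    # Point 3: a stateless variant that reads the scale from explicit state,
    # plus a stateful convenience wrapper around it.
    def stateless_scale_loss(self, loss: float, state: ScaleState) -> float:
        return loss * state.loss_scale

    def scale_loss(self, loss: float) -> float:
        return self.stateless_scale_loss(loss, self.state)

    # Point 1: both branches return (params, new_state) with the same
    # structure and mutate nothing, as a stateless cond requires.
    def stateless_apply(self, params: List[float], scaled_grads: List[float],
                        state: ScaleState) -> Tuple[List[float], ScaleState]:
        grads = [g / state.loss_scale for g in scaled_grads]

        def finite_branch():
            # Apply the unscaled gradients; double the scale after
            # growth_steps consecutive finite steps.
            steps = state.good_steps + 1
            if steps >= self.growth_steps:
                new_state = ScaleState(state.loss_scale * 2.0, 0)
            else:
                new_state = ScaleState(state.loss_scale, steps)
            return self.inner.apply(params, grads), new_state

        def nonfinite_branch():
            # Overflow: skip the update and halve the scale.
            return list(params), ScaleState(state.loss_scale / 2.0, 0)

        all_finite = all(math.isfinite(g) for g in grads)
        # Stand-in for ops.cond(all_finite, finite_branch, nonfinite_branch).
        return finite_branch() if all_finite else nonfinite_branch()

    def apply(self, params: List[float], scaled_grads: List[float]) -> List[float]:
        params, self.state = self.stateless_apply(params, scaled_grads, self.state)
        return params
```

For example, with initial_scale=8.0 and growth_steps=2, two finite steps double the scale to 16.0, and a subsequent inf gradient leaves the parameters untouched and halves the scale back to 8.0. Because this toy class has no base class, it also sidesteps the dummy zero learning rate described in point 2.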
mattdangerw commented 1 year ago

Looks like there is some sort of device placement issue on TensorFlow GPU that isn't caught by our CPU-only tests? Will poke around. Confirmed that torch and JAX are working.

[Edit: now working on all backends]

mattdangerw commented 1 year ago

Addressed the initial round of review comments, though I may still try a test that deliberately triggers underflow in the trainer and asserts that the variable updates still happen. I don't think that should be too hard, but we will see.
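The underflow test idea above can be sketched without any framework. This is a hypothetical pure-Python analogue, not the test added in this PR: the one-weight model, its loss 0.5 * (w * x - target)**2, the helper names and the chosen inputs are all assumptions, and float16 is emulated with the struct module. The inputs are picked so the true gradient is about 1e-8, which rounds to zero in float16; the test asserts that the weight only moves when loss scaling is applied.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def float16_grad(w: float, x: float, target: float, loss_scale: float) -> float:
    """Gradient of the hypothetical loss 0.5 * (w * x - target)**2 w.r.t. w.
    Scaling the loss scales the gradient by the same factor (chain rule);
    the result is then cast to float16 to emulate a half-precision backward."""
    grad = (w * x - target) * x * loss_scale
    return to_float16(grad)


def train_step(w: float, x: float, target: float, learning_rate: float,
               loss_scale: float = 1.0) -> float:
    scaled_grad = float16_grad(w, x, target, loss_scale)
    grad = scaled_grad / loss_scale  # unscale in full precision
    return w - learning_rate * grad


def test_loss_scaling_prevents_underflow():
    # True gradient is (1.0 * 1e-4 - 0.0) * 1e-4 = 1e-8, which underflows.
    w, x, target = 1.0, 1e-4, 0.0
    w_plain = train_step(w, x, target, learning_rate=1.0)
    w_scaled = train_step(w, x, target, learning_rate=1.0, loss_scale=2.0**15)
    assert w_plain == w   # update silently lost to float16 underflow
    assert w_scaled < w   # update survives with loss scaling


if __name__ == "__main__":
    test_loss_scaling_prevents_underflow()
    print("ok")
```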

mattdangerw commented 12 months ago

Added an end-to-end test that only looks at variable updates across fit(), which will hopefully keep us from accidentally breaking this.