keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Memory optimization for jax trainer. #888

Closed: qlzh727 closed this 11 months ago

qlzh727 commented 11 months ago

Update the JAX trainer to use less memory during fit/eval.

Currently, we keep three copies of all the model's variable state during training:

  1. The jax.array attached to the KerasVariable.
  2. The jax.array passed as input to the jax.jit() train/eval function.
  3. The jax.array returned from the jax.jit() train/eval function, which is also stored in trainer.jax_state.

Copies 2 and 3 keep getting updated during training, but copy 1 becomes a stale copy that unnecessarily occupies heap memory. For large models, this overhead is huge.

This PR purges copy 1 and restores the KerasVariable values at the end of each epoch. Since one of the three copies is dropped, this should save about 33% of the variable memory. In early internal tests, per-device memory usage dropped from 9.49G to 6.65G (about 30%) for an OPT2 model.
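A minimal sketch of the purge-and-restore idea, assuming a toy `ToyVariable` class and a hypothetical `fit_one_epoch` helper (neither is keras-core API, and the update rule is made up). The variables drop their own array references before the loop, so only the state threaded through the jitted step stays alive, and the values are written back once the epoch ends.

```python
import jax
import jax.numpy as jnp


class ToyVariable:
    """Hypothetical stand-in for a KerasVariable: it only holds an array."""

    def __init__(self, value):
        self.value = value


@jax.jit
def train_step(state, batch):
    # Toy pure update: takes the current state and returns the new state.
    return [w - 0.1 * jnp.mean(batch) * w for w in state]


def fit_one_epoch(variables, batches):
    """Hypothetical epoch loop illustrating the purge/restore pattern."""
    # Collect the current arrays; they become the jitted step's inputs.
    state = [v.value for v in variables]
    # Purge: without this, each variable would pin its epoch-start array
    # in memory while `state` moves on (the stale "copy 1").
    for v in variables:
        v.value = None
    for batch in batches:
        # Rebinding `state` lets the previous step's arrays be freed.
        state = train_step(state, batch)
    # Restore the variables from the latest state at the end of the epoch.
    for v, new_value in zip(variables, state):
        v.value = new_value
    return variables
```

The trade-off in this sketch is that the variables hold no value while the loop runs, so anything reading them mid-epoch must go through the threaded state instead.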

I will send a follow-up PR to address the additional memory usage in the eval and predict functions.

codecov[bot] commented 11 months ago

Codecov Report

Patch coverage: 12.50% and project coverage change: -0.06% :warning:

Comparison is base (1704ecf) 79.75% compared to head (3fd4ad4) 79.70%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #888      +/-   ##
==========================================
- Coverage   79.75%   79.70%   -0.06%
==========================================
  Files         318      318
  Lines       28638    28657      +19
  Branches     5451     5460       +9
==========================================
  Hits        22841    22841
- Misses       4333     4351      +18
- Partials     1464     1465       +1
```

| Flag | Coverage Δ |
|---|---|
| keras_core | `79.62% <12.50%> (-0.06%)` :arrow_down: |
| keras_core-numpy | `60.37% <8.33%> (-0.04%)` :arrow_down: |
| keras_core-tensorflow | `66.78% <12.50%> (-0.05%)` :arrow_down: |
| keras_core-torch | `69.21% <12.50%> (-0.05%)` :arrow_down: |

Flags with carried forward coverage won't be shown.

| Files Changed | Coverage Δ |
|---|---|
| `keras_core/backend/jax/trainer.py` | `0.00% <0.00%> (ø)` |
| `keras_core/layers/layer.py` | `86.64% <0.00%> (-0.28%)` :arrow_down: |
| `keras_core/backend/common/variables.py` | `75.42% <75.00%> (ø)` |


qlzh727 commented 11 months ago

Let's park this PR for now, since the unit tests uncovered some underlying issues that need to be addressed.

qlzh727 commented 11 months ago

OK, I think I figured out the issues.

  1. The regularizer failure is actually a bug in the existing code, where a stale version of the variable value was used. layer.py is updated to retrieve the latest value when inside a stateless scope.
  2. The RNN test failed because it creates a new stateless scope under the hood, and that new scope doesn't have any variable mapping. I made the creation of the StatelessScope conditional to fix the issue.

qlzh727 commented 11 months ago

Hmm, it's trickier than I expected, especially for RNNs. In RNNs (LSTM/GRU) with dropout, the seed generator gets updated inside the jax.lax.scan() call in https://github.com/keras-team/keras-core/blob/b4019bc1bfa2c2f15decb533d054cbcf31562124/keras_core/backend/jax/rnn.py#L188 (there is even a comment there saying the function needs to be stateless). Since the step function doesn't return the RNG state as an explicit result, when we capture the updated RNG state from the train function and reuse it for the JAX constraint, JAX complains about this state leaking out of the scan function.
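One common way to keep RNG state from leaking out of `jax.lax.scan` is to thread the key through the scan carry, so each step consumes a fresh subkey and the final key comes back as an explicit output. A minimal sketch, assuming a toy tanh recurrence with inverted dropout; `dropout_rnn` is a hypothetical name, not the keras-core RNN implementation.

```python
import jax
import jax.numpy as jnp


def dropout_rnn(h0, xs, key, rate=0.5):
    """Hypothetical toy RNN whose dropout RNG state lives in the scan carry."""

    def step(carry, x):
        h, key = carry
        # Split inside the step and carry the new key forward, instead of
        # mutating a seed generator that lives outside the traced function.
        key, subkey = jax.random.split(key)
        keep = jax.random.bernoulli(subkey, 1.0 - rate, x.shape)
        x = jnp.where(keep, x / (1.0 - rate), 0.0)  # inverted dropout
        h = jnp.tanh(h + x)
        return (h, key), h

    (h_final, key_final), hs = jax.lax.scan(step, (h0, key), xs)
    # The updated key is an explicit output, so the caller can store it as
    # the new RNG state; nothing escapes the scan implicitly.
    return h_final, hs, key_final
```

Because the key is part of the carry, the step function is pure and the whole loop also works under `jax.jit`.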

I think the current repo is probably not working correctly: the StatelessScope for the scan function is thrown away, which means the RNG seed update is lost.

Will take a closer look tomorrow.

fchollet commented 11 months ago

> I think the current repo is probably not working correctly: the StatelessScope for the scan function is thrown away, which means the RNG seed update is lost.

We previously observed some weirdness in JAX RNNs (though it doesn't seem to be connected to dropout). Maybe they're related? https://github.com/keras-team/keras-core/issues/322

fchollet commented 11 months ago

> Since the step function doesn't return the RNG state as an explicit result, when we capture the updated RNG state from the train function and reuse it for the JAX constraint, JAX complains about this state leaking out of the scan function.

I think we might want to add stateless versions of the step function, etc.

qlzh727 commented 11 months ago

> I think the current repo is probably not working correctly: the StatelessScope for the scan function is thrown away, which means the RNG seed update is lost.

> We previously observed some weirdness in JAX RNNs (though it doesn't seem to be connected to dropout). Maybe they're related? #322

The particular issue we hit is with the RNG state, which gets updated when dropout is used. It runs fine if I disable dropout for the layer. I need to dig a bit more.

qlzh727 commented 11 months ago

Having said that, I think my PR will actually fix the issue in https://github.com/keras-team/keras-core/issues/322. Because of the stateless scope in the jax.lax.scan function, the step function only reads the stale variable values, which is probably why the model does not train properly.
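The general stale-read failure mode can be reproduced in plain JAX, independent of keras-core; this is an analogy, not the actual keras-core code path. A value captured from outside a jitted function is frozen in when the function is traced, whereas a value passed in as an argument is read fresh on every call. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp

weight = jnp.array(1.0)


@jax.jit
def step_with_captured_weight(x):
    # `weight` is read once, at trace time, and baked into the compiled code.
    return x * weight


@jax.jit
def step_with_explicit_weight(w, x):
    # The weight is an input, so every call sees the caller's current value.
    return x * w


x = jnp.array(2.0)
first = step_with_captured_weight(x)   # traced while weight == 1.0
weight = jnp.array(3.0)                # "update" the variable outside jit
stale = step_with_captured_weight(x)   # cached trace still multiplies by 1.0
fresh = step_with_explicit_weight(weight, x)  # multiplies by 3.0
```

Threading state through as explicit inputs and outputs, as the stateless trainer does, is what keeps each step reading current values.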

qlzh727 commented 11 months ago

So the stateless RNN test failure should be addressed by https://github.com/keras-team/keras-core/pull/924, and the test for this PR should pass now.