keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Update jax trainer function to save memory buffer. #897

Closed qlzh727 closed 10 months ago

qlzh727 commented 10 months ago

This is another attempt to save memory in the training function (in addition to https://github.com/keras-team/keras-core/pull/888).

With `jax.jit(..., donate_argnums=...)`, we can force JAX to reuse the input argument's memory buffer for the output, which saves one copy of the state in memory.

Since a donated buffer can't be read afterwards, I had to update the eval/test_on_batch functions to use the output trainable_variables, since the original input copy has been donated.
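The donation mechanism can be sketched with a toy step function (this is an illustration of `donate_argnums`, not the keras-core trainer code):

```python
import functools

import jax
import jax.numpy as jnp

# Toy sketch of the donation idea: donate_argnums=(0,) tells XLA it may
# reuse the buffer of the first argument for the output, avoiding a second
# copy of the (potentially large) state in memory.
@functools.partial(jax.jit, donate_argnums=(0,))
def train_step(state, grads):
    # The returned state can alias the donated input buffer.
    return state - 0.1 * grads

state = jnp.ones((4,))
# Rebind to the output: the old `state` buffer has been donated and
# must not be read again afterwards.
state = train_step(state, jnp.ones((4,)))
```

Note that on backends where donation is not supported, JAX ignores it with a warning and the code still runs; where it is honored, reading the old buffer afterwards raises the "buffer has been deleted or donated" error.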

The xprof for the OPT model already shows a positive result: Before: https://xprof.corp.google.com/memory_profile/scottzhu-15065682269222877644 After: https://xprof.corp.google.com/memory_profile/scottzhu-6823282872073158074.

As you can see, the heap allocation is greatly reduced.

codecov[bot] commented 10 months ago

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.01% :tada:

Comparison is base (9d39e9a) 76.83% compared to head (4aecd5b) 76.85%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #897      +/-   ##
==========================================
+ Coverage   76.83%   76.85%   +0.01%
==========================================
  Files         329      329
  Lines       31434    31435       +1
  Branches     6114     6114
==========================================
+ Hits        24151    24158       +7
+ Misses       5719     5715       -4
+ Partials     1564     1562       -2
```

| Flag | Coverage Δ | |
|---|---|---|
| keras_core | `76.75% <100.00%> (+0.01%)` | :arrow_up: |

Flags with carried forward coverage won't be shown.

| Files Changed | Coverage Δ | |
|---|---|---|
| keras_core/backend/jax/trainer.py | `96.09% <100.00%> (+0.01%)` | :arrow_up: |

... and 1 file with indirect coverage changes.


fchollet commented 10 months ago

Thanks for the PR! The code looks good.

Seems there's a CI failure though:

```
FAILED keras_core/trainers/trainer_test.py::TestTrainer::test_metric_tracking - jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Invalid buffer passed: buffer has been deleted or donated.
```

qlzh727 commented 10 months ago

> Thanks for the PR! The code looks good.
>
> Seems there's a CI failure though:
>
> ```
> FAILED keras_core/trainers/trainer_test.py::TestTrainer::test_metric_tracking - jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Invalid buffer passed: buffer has been deleted or donated.
> ```

So this is a bit complicated. In the particular test that failed, there are variables that appear in both the non-trainable variables and the metrics variables. This was causing an issue: the memory got donated first for the non-trainable variable, and the call then failed when the same variable was accessed as a metrics variable.

In general, I am not sure whether we should make the non-trainable variables and metric variables mutually exclusive (which I think makes more sense; e.g. the optimizer variables are not considered trainable/non-trainable variables).

qlzh727 commented 10 months ago

Also, taking a closer look at the trainer/layer code, it seems that we have non_trainable_variables, which contains metrics, and non_trainable_weights, which doesn't contain metrics. Just curious why we have this distinction?

I am not sure which approach is better and makes more logical sense:

  1. Update trainer.non_trainable_variables to exclude anything in metrics_variables.
  2. Use non_trainable_weights as the input to the model training function; this might be complicated due to the RNG seed.
fchollet commented 10 months ago

"Non-trainable variables include metric variables" is something we say in a bunch of places. The benefit of that is that you only have to manage trainable_variables + non_trainable_variables as the whole state of the model, including any metrics attached to it, plus random seeds, etc.

We could certainly change that, and exclude metric variables. The fact that our JAX train_step separates both is a point in that direction. Another point: metric variables are not included in saving (this is because their state is not useful to keep across reloads).

The other route we could go is to actually double down on the idea that non_trainable_variables include metrics variables. In that case we'd try to stop special casing metric_variables in the JAX trainer.

Which is better?
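That "whole state of the model" contract can be illustrated with a toy, purely functional step (plain Python with made-up numbers; not the actual keras-core train step):

```python
# Toy, purely functional train step: the full model state is just
# (trainable_variables, non_trainable_variables), passed in and returned.
def train_step(trainable_variables, non_trainable_variables, x, y):
    (w,) = trainable_variables
    (steps_seen,) = non_trainable_variables  # e.g. a counter or RNG-seed slot
    pred = w * x
    loss = (pred - y) ** 2
    grad = 2 * (pred - y) * x                # d(loss)/dw
    new_w = w - 0.1 * grad                   # SGD update
    # Return the updated state alongside the loss.
    return loss, [new_w], [steps_seen + 1]

loss, trainable, non_trainable = train_step([1.0], [0], x=2.0, y=4.0)
```

The appeal of folding metrics into non_trainable_variables is that this two-list signature then really is the entire model state; the cost is the kind of aliasing seen in the failing test above.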

qlzh727 commented 10 months ago

> "Non-trainable variables include metric variables" is something we say in a bunch of places. The benefit of that is that you only have to manage trainable_variables + non_trainable_variables as the whole state of the model, including any metrics attached to it, plus random seeds, etc.
>
> We could certainly change that, and exclude metric variables. The fact that our JAX train_step separates both is a point in that direction. Another point: metric variables are not included in saving (this is because their state is not useful to keep across reloads).
>
> The other route we could go is to actually double down on the idea that non_trainable_variables include metrics variables. In that case we'd try to stop special-casing metric_variables in the JAX trainer.
>
> Which is better?

I see. I think it makes sense to exclude the metrics from the non-trainable variables. My understanding is that trainable/non-trainable variables are used in the training/inference process and affect the numerical output of the model. E.g. the beta/gamma of BatchNormalization is a good example of a non-trainable variable. The same goes for the seed, which is used to generate RNGs, either for initializers or for dropout. The metrics variables, on the other hand, don't affect the training/inference outcome, whereas even the optimizer weights make a big contribution to the model output.

Also, when I print out the state of the model under test, the non-trainable variables don't even include all the metrics variables. They include the weights attached to the model as metrics, but not those from model.compile():

```
keras_core/trainers/trainer_test.py::TestTrainer::test_metric_tracking

print(model.non_trainable_variables)
[<KerasVariable shape=(), dtype=float32, path=my_metric/total>, <KerasVariable shape=(), dtype=float32, path=my_metric/count>]

print(model.metrics_variables)
[<KerasVariable shape=(), dtype=float32, path=loss/total>, <KerasVariable shape=(), dtype=float32, path=loss/count>, <KerasVariable shape=(), dtype=float32, path=my_metric/total>, <KerasVariable shape=(), dtype=float32, path=my_metric/count>]
```

Having said that, if we take this approach, we won't have an easy way to visit those variables unless we explicitly walk all the metrics of the layer and collect their weights. Should we add a metrics_variables attribute to the layer?

fchollet commented 10 months ago

Ok, sounds good, let's keep both lists separate.

> Should we add a metrics_variables attribute for the layer?

Yes, let's do that. It should be on the Layer class, but also overridden on the Trainer class to add the compile-metrics variables if the model is compiled.

qlzh727 commented 10 months ago

Ack. Done https://github.com/keras-team/keras-core/pull/910.

qlzh727 commented 10 months ago

Rebased this PR, and the unit test should pass now.