Closed qlzh727 closed 10 months ago
Patch coverage: 100.00% and project coverage change: +0.01% :tada:
Comparison is base (9d39e9a) 76.83% compared to head (4aecd5b) 76.85%.
Thanks for the PR! The code looks good.
Seems there's a CI failure though:
FAILED keras_core/trainers/trainer_test.py::TestTrainer::test_metric_tracking - jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Invalid buffer passed: buffer has been deleted or donated.
So this is a bit complicated. For the particular test that failed, some variables appear in both the non-trainable variables and the metrics variables. This caused an issue: the memory gets donated once for the non-trainable variable, and the call then fails when the same variable is accessed again as a metrics variable.
In general, I am not sure whether we should make non-trainable variables and metric variables mutually exclusive (which I think makes more sense; e.g. optimizer variables are not considered trainable or non-trainable variables).
Also, taking a closer look at the trainer/layer code, it seems that we have non_trainable_variables, which contains metrics, and non_trainable_weights, which doesn't contain metrics. Just curious why we have this differentiation?
I am not sure which approach is better and makes more logical sense:
"non trainable variables include metric variables" is something we say in a bunch of places. The benefit of that is that you only have to manage trainable_variables
+ non_trainable_variables
as being the whole state of the model, included any metrics attached to it, plus random seeds, etc.
We could certainly change that, and exclude metric variables. The fact that our JAX train_step
separates both is a point in that direction. Another point: metric variables are not included in saving (this is because their state is not useful to keep across reloads).
The other route we could go is to actually double down on the idea that non_trainable_variables
include metrics variables. In that case we'd try to stop special casing metric_variables
in the JAX trainer.
Which is better?
"non trainable variables include metric variables" is something we say in a bunch of places. The benefit of that is that you only have to manage
trainable_variables
+non_trainable_variables
as being the whole state of the model, included any metrics attached to it, plus random seeds, etc.We could certainly change that, and exclude metric variables. The fact that our JAX
train_step
separates both is a point in that direction. Another point: metric variables are not included in saving (this is because their state is not useful to keep across reloads).The other route we could go is to actually double down on the idea that
non_trainable_variables
include metrics variables. In that case we'd try to stop special casingmetric_variables
in the JAX trainer.Which is better?
I see. I think it might make sense to exclude the metrics from the non-trainable variables. My understanding is that trainable/non-trainable variables are used in the training/inference process and affect the numerical output of the model. E.g. the beta/gamma for BN is a good case of a non-trainable variable. Same for the seed, which is used to generate the RNGs, either for an initializer or for dropout. The metrics variables, on the other hand, don't affect the training/inference outcome, whereas even the optimizer weights make a big contribution to the model output.
Also, when I print out the state for the model under testing, the non-trainable variables don't even include all the metrics variables. They include the weights that are attached to the model as metrics, but not those added under model.compile():
For keras_core/trainers/trainer_test.py::TestTrainer::test_metric_tracking:

```python
print(model.non_trainable_variables)
# [<KerasVariable shape=(), dtype=float32, path=my_metric/total>,
#  <KerasVariable shape=(), dtype=float32, path=my_metric/count>]
print(model.metrics_variables)
# [<KerasVariable shape=(), dtype=float32, path=loss/total>,
#  <KerasVariable shape=(), dtype=float32, path=loss/count>,
#  <KerasVariable shape=(), dtype=float32, path=my_metric/total>,
#  <KerasVariable shape=(), dtype=float32, path=my_metric/count>]
```
Having said that, if we take this approach, then we won't have an easy way to visit those variables, unless we explicitly visit all the metrics for the layer and find those weights. Should we add a metrics_variables attribute for the layer?
Ok, sounds good, let's keep both lists separate.
> Should we add a metrics_variables attribute for the layer?
Yes, let's do that. It should be on the Layer class, but also overridden on the Trainer class to add the compile metrics variables if the model is compiled.
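A rough sketch of what that could look like (toy stand-in classes, not the actual Keras Core implementation):

```python
class Metric:
    """Toy metric holding two state variables (e.g. total and count)."""
    def __init__(self, name):
        self.variables = [f"{name}/total", f"{name}/count"]

class Layer:
    def __init__(self):
        self._metrics = []  # metrics attached directly to the layer

    @property
    def metrics_variables(self):
        # Collect the variables of every metric attached to this layer.
        variables = []
        for metric in self._metrics:
            variables.extend(metric.variables)
        return variables

class Trainer(Layer):
    def __init__(self):
        super().__init__()
        self.compiled = False
        self._compile_metrics = []  # metrics passed to compile()

    @property
    def metrics_variables(self):
        # Layer-level metrics first, then compile() metrics if compiled.
        variables = list(super().metrics_variables)
        if self.compiled:
            for metric in self._compile_metrics:
                variables.extend(metric.variables)
        return variables
```

This would make the compile-time metric state (e.g. the loss trackers) reachable through the same attribute as layer-attached metrics.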
Rebased this PR, and the unit test should pass now.
This is another attempt to save memory in the training function (in addition to https://github.com/keras-team/keras-core/pull/888).
With jax.jit(donate_argnums=...), we can force JAX to reuse the input argument's memory buffer for the output, which saves one copy's worth of memory.
Since a donated buffer can't be accessed again, I had to update the eval/test_on_batch function to use the output trainable_variables, since the original input copy has been donated.
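A minimal sketch of what buffer donation looks like (a toy function, not the actual train step):

```python
import jax
import jax.numpy as jnp

# Toy example of jax.jit buffer donation: with donate_argnums, XLA may
# reuse the input buffer for the output, avoiding a second copy of the
# array in device memory.
step = jax.jit(lambda x: x + 1.0, donate_argnums=(0,))

x = jnp.ones((3,))
y = step(x)  # x's buffer may now back y
# The caller must switch to the returned value; reading the donated
# input again can fail on GPU/TPU with "buffer has been deleted or
# donated" (on CPU donation is ignored with a warning).
```

This is why the eval path has to consume the returned trainable_variables rather than the originals it passed in.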
The xprof for the OPT model already shows some positive results: Before: https://xprof.corp.google.com/memory_profile/scottzhu-15065682269222877644 After: https://xprof.corp.google.com/memory_profile/scottzhu-6823282872073158074.
As you can see, the heap allocation is greatly reduced.