keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0
62k stars, 19.48k forks

model.keras format much slower to load #19793

Closed sirfz closed 5 months ago

sirfz commented 5 months ago

Is anyone else experiencing unreasonably slow load times when loading a model saved in the keras format? I have noticed this repeatedly when working in IPython: simply instantiating a model via Model.from_config and then calling model.load_weights is much (several times) faster than loading a model.keras file.

My understanding is that the keras format is simply a zip file containing the config.json file and the weights in h5 (iirc), but weirdly enough, something is not right during loading.
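One way to check that understanding is to list the archive with the standard `zipfile` module. This is a minimal sketch: it builds a stand-in archive so that it runs without Keras, and the member names (`config.json`, `metadata.json`, `model.weights.h5`) are assumptions about a typical Keras 3 save, not something confirmed in this thread. `list_members` is a hypothetical helper; point it at a real `model.keras` to see the actual layout.

```python
import io
import zipfile


def list_members(archive_file):
    """Return (name, uncompressed size, compression type) for each zip member."""
    with zipfile.ZipFile(archive_file) as archive:
        return [
            (info.filename, info.file_size, info.compress_type)
            for info in archive.infolist()
        ]


# Stand-in archive; the member names are assumptions, not read from Keras.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("config.json", "{}")
    archive.writestr("metadata.json", "{}")
    archive.writestr("model.weights.h5", b"\x00" * 1024)

for name, size, compress_type in list_members(buffer):
    print(f"{name}: {size} bytes, compress_type={compress_type}")
```

With a real file, `list_members("model.keras")` shows whether the weights member is stored or compressed, which matters for how expensive random access into it is.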

sachinprasadhs commented 5 months ago

Could you please provide a comparison and the time difference when loading the model with .keras and with the other format?

For more details on the changes included with the .keras format and why it is preferred over other formats, refer to https://keras.io/guides/serialization_and_saving/

sirfz commented 5 months ago

I don't have an example at the moment, but we recently updated our prod system from keras 2 to keras 3, so we converted all legacy saved models to the new keras 3 format. This led to our service taking over 12 minutes to load all models (>15 models loading in subprocesses in parallel). Moving to from_config + load_weights reduced the time to ~2 minutes (which is on par with what we had before).

For what it's worth, before we did that migration, I was already working on GPT2Backbone models with keras-nlp and noticed the same issue, where loading the .keras model was really slow (but I didn't give it much thought at the time).

fchollet commented 5 months ago

What you're using is actually the same as what load_model is using except for the interaction with the zip file. So perhaps the zip file reading is the issue.

sirfz commented 5 months ago

100% which is why I find this very odd

james77777778 commented 5 months ago

I encountered this issue before when trying to quantize Gemma

I have created this script to demonstrate the issue (using GPT-2)

check_loading.py

import argparse
import json

import keras
import keras_nlp

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--mode",
        default="save",
        choices=["save", "load", "load_weights"],
    )
    args = parser.parse_args()
    return args

def main(args):
    if args.mode == "save":
        model = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
        # Save keras file
        model.save("model.keras")
        # Save serialized config and weights
        config = keras.saving.serialize_keras_object(model)
        with open("model.json", "w") as f:
            json.dump(config, f)
        model.save_weights("model.weights.h5")
    elif args.mode == "load":
        model = keras.saving.load_model("model.keras")
    else:
        with open("model.json", "r") as f:
            config = json.load(f)
        model = keras.saving.deserialize_keras_object(config)
        model.load_weights("model.weights.h5")

if __name__ == "__main__":
    keras.config.disable_traceback_filtering()
    main(get_args())

Usage:

# 1. Save the model
python check_loading.py -m save
# 2. Profile `load_model`
pyinstrument python check_loading.py -m load
# 3. Profile `deserialize_keras_object` and `load_weights`
pyinstrument python check_loading.py -m load_weights

The result:

| Method | Time |
| --- | --- |
| load_model | 27.861s |
| deserialize_keras_object + load_weights | 3.166s |

Logs:

```console
  _     ._   __/__   _ _  _  _ _/_   Recorded: 10:05:02  Samples:  10954
 /_//_/// /_\ / //_// / //_'/ //     Duration: 27.861    CPU time: 30.009
/   _/                      v4.6.2

Program: /home/hongyu/miniconda3/envs/kimm/bin/pyinstrument check_loading.py -m load

27.861 check_loading.py:1
├─ 25.635 main  check_loading.py:20
│  └─ 25.635 load_model  keras/src/saving/saving_api.py:116
│     └─ 25.634 load_model  keras/src/saving/saving_lib.py:138
│        └─ 25.634 _load_model_from_fileobj  keras/src/saving/saving_lib.py:157
│           ├─ 24.319 _load_state  keras/src/saving/saving_lib.py:395
│           │  ├─ 23.507 _load_container_state  keras/src/saving/saving_lib.py:510
│           │  │  └─ 23.507 _load_state  keras/src/saving/saving_lib.py:395
│           │  │     └─ 23.505 _load_container_state  keras/src/saving/saving_lib.py:510
│           │  │        └─ 23.504 _load_state  keras/src/saving/saving_lib.py:395
│           │  │           ├─ 21.286 _load_state  keras/src/saving/saving_lib.py:395
│           │  │           │  ├─ 9.102 _load_state  keras/src/saving/saving_lib.py:395
│           │  │           │  │  ├─ 5.381 H5IOStore.get  keras/src/saving/saving_lib.py:632
│           │  │           │  │  │  └─ 5.381 H5Entry.__init__  keras/src/saving/saving_lib.py:646
│           │  │           │  │  │     ├─ 3.618 Group.__contains__  h5py/_hl/group.py:508
│           │  │           │  │  │     │     [6 frames hidden]  h5py, zipfile,
│           │  │           │  │  │     └─ 1.763 File.__getitem__  h5py/_hl/group.py:348
│           │  │           │  │  │           [6 frames hidden]  h5py, zipfile,
│           │  │           │  │  └─ 3.717 EinsumDense.load_own_variables  keras/src/layers/core/einsum_dense.py:279
│           │  │           │  │     └─ 3.579 H5Entry.__getitem__  keras/src/saving/saving_lib.py:702
│           │  │           │  │        └─ 3.577 Group.__getitem__  h5py/_hl/group.py:348
│           │  │           │  │              [6 frames hidden]  h5py, zipfile,
│           │  │           │  ├─ 7.054 H5IOStore.get  keras/src/saving/saving_lib.py:632
│           │  │           │  │  └─ 7.054 H5Entry.__init__  keras/src/saving/saving_lib.py:646
│           │  │           │  │     ├─ 4.377 Group.__contains__  h5py/_hl/group.py:508
│           │  │           │  │     │     [9 frames hidden]  h5py, zipfile,
│           │  │           │  │     └─ 2.677 Group.__getitem__  h5py/_hl/group.py:348
│           │  │           │  │           [6 frames hidden]  h5py, zipfile,
│           │  │           │  ├─ 3.121 LayerNormalization.load_own_variables  keras/src/layers/layer.py:1187
│           │  │           │  │  ├─ 1.936 H5Entry.__getitem__  keras/src/saving/saving_lib.py:702
│           │  │           │  │  │  └─ 1.935 Group.__getitem__  h5py/_hl/group.py:348
│           │  │           │  │  │        [6 frames hidden]  h5py, zipfile,
│           │  │           │  │  └─ 0.967 Variable.assign  keras/src/backend/common/variables.py:223
│           │  │           │  │     └─ 0.962 Variable._convert_to_tensor  keras/src/backend/tensorflow/core.py:53
│           │  │           │  │        └─ 0.962 convert_to_tensor  keras/src/backend/tensorflow/core.py:102
│           │  │           │  │           └─ 0.961 error_handler  tensorflow/python/util/traceback_utils.py:138
│           │  │           │  │                 [18 frames hidden]  tensorflow, h5py, zipfile,
│           │  │           │  └─ 1.978 Dense.load_own_variables  keras/src/layers/core/dense.py:224
│           │  │           │     └─ 1.690 H5Entry.__getitem__  keras/src/saving/saving_lib.py:702
│           │  │           │        └─ 1.690 Group.__getitem__  h5py/_hl/group.py:348
│           │  │           │              [6 frames hidden]  h5py, zipfile,
│           │  │           ├─ 1.576 H5IOStore.get  keras/src/saving/saving_lib.py:632
│           │  │           │  └─ 1.576 H5Entry.__init__  keras/src/saving/saving_lib.py:646
│           │  │           │     └─ 1.391 Group.__contains__  h5py/_hl/group.py:508
│           │  │           │           [6 frames hidden]  h5py, zipfile,
│           │  │           ├─ 0.344 ReversibleEmbedding.load_own_variables  keras_nlp/src/layers/modeling/reversible_embedding.py:151
│           │  │           │  └─ 0.344 ReversibleEmbedding.load_own_variables  keras/src/layers/core/embedding.py:214
│           │  │           │     └─ 0.288 Variable.assign  keras/src/backend/common/variables.py:223
│           │  │           │        └─ 0.288 Variable._convert_to_tensor  keras/src/backend/tensorflow/core.py:53
│           │  │           │           └─ 0.288 convert_to_tensor  keras/src/backend/tensorflow/core.py:102
│           │  │           │              └─ 0.288 error_handler  tensorflow/python/util/traceback_utils.py:138
│           │  │           │                    [11 frames hidden]  tensorflow
│           │  │           └─ 0.298 TransformerDecoder.load_own_variables  keras/src/layers/layer.py:1187
│           │  └─ 0.809 _load_state  keras/src/saving/saving_lib.py:395
│           │     └─ 0.586 _load_state  keras/src/saving/saving_lib.py:395
│           │        └─ 0.467 GPT2Tokenizer.load_assets  keras_nlp/src/tokenizers/byte_pair_tokenizer.py:327
│           │              [3 frames hidden]  keras_nlp
│           ├─ 0.534 deserialize_keras_object  keras/src/saving/serialization_lib.py:393
│           │  └─ 0.534 GPT2CausalLM.from_config  keras_nlp/src/models/task.py:143
│           │     └─ 0.522 deserialize  keras/src/layers/__init__.py:153
│           │        └─ 0.522 deserialize_keras_object  keras/src/saving/serialization_lib.py:393
│           │           └─ 0.516 GPT2Backbone.from_config  keras_nlp/src/models/backbone.py:139
│           │                 [3 frames hidden]  keras_nlp
│           │                    0.293 TransformerDecoder.__call__  keras_nlp/src/layers/modeling/transformer_decoder.py:253
│           │                    └─ 0.288 build_wrapper  keras/src/layers/layer.py:220
│           │                       └─ 0.288 TransformerDecoder.build  keras_nlp/src/layers/modeling/transformer_decoder.py:134
│           └─ 0.458 DiskIOStore.__init__  keras/src/saving/saving_lib.py:556
│              └─ 0.458 ZipFile.extractall  zipfile.py:1666
│                    [4 frames hidden]  zipfile, shutil,
└─ 2.121 keras/__init__.py:1
   └─ 2.121 keras/api/__init__.py:1
      └─ 2.120 keras/api/_tf_keras/__init__.py:1
         └─ 2.120 keras/api/_tf_keras/keras/__init__.py:1
            └─ 2.110 keras/api/activations/__init__.py:1
               └─ 2.110 keras/src/__init__.py:1
                  └─ 2.109 keras/src/activations/__init__.py:1
                     └─ 1.921 keras/src/activations/activations.py:1
                        └─ 1.917 keras/src/backend/__init__.py:1
                           └─ 1.824 keras/src/backend/common/__init__.py:1
                              └─ 1.823 keras/src/backend/common/dtypes.py:1
                                 └─ 1.823 keras/src/backend/common/variables.py:1
                                    └─ 1.758 keras/src/utils/__init__.py:1
                                       └─ 1.721 keras/src/utils/model_visualization.py:1
                                          └─ 1.705 keras/src/tree/__init__.py:1
                                             └─ 1.705 keras/src/tree/tree_api.py:1
                                                └─ 1.699 keras/src/tree/optree_impl.py:1
                                                   └─ 1.698 tensorflow/__init__.py:1
                                                         [23 frames hidden]  tensorflow

  _     ._   __/__   _ _  _  _ _/_   Recorded: 10:05:39  Samples:  2266
 /_//_/// /_\ / //_// / //_'/ //     Duration: 3.166     CPU time: 5.276
/   _/                      v4.6.2

Program: /home/hongyu/miniconda3/envs/kimm/bin/pyinstrument check_loading.py -m load_weights

3.165 check_loading.py:1
├─ 2.121 keras/__init__.py:1
│  └─ 2.121 keras/api/__init__.py:1
│     └─ 2.120 keras/api/_tf_keras/__init__.py:1
│        └─ 2.120 keras/api/_tf_keras/keras/__init__.py:1
│           └─ 2.110 keras/api/activations/__init__.py:1
│              └─ 2.110 keras/src/__init__.py:1
│                 └─ 2.109 keras/src/activations/__init__.py:1
│                    ├─ 1.922 keras/src/activations/activations.py:1
│                    │  └─ 1.917 keras/src/backend/__init__.py:1
│                    │     ├─ 1.825 keras/src/backend/common/__init__.py:1
│                    │     │  └─ 1.824 keras/src/backend/common/dtypes.py:1
│                    │     │     └─ 1.824 keras/src/backend/common/variables.py:1
│                    │     │        ├─ 1.760 keras/src/utils/__init__.py:1
│                    │     │        │  └─ 1.722 keras/src/utils/model_visualization.py:1
│                    │     │        │     └─ 1.707 keras/src/tree/__init__.py:1
│                    │     │        │        └─ 1.707 keras/src/tree/tree_api.py:1
│                    │     │        │           └─ 1.701 keras/src/tree/optree_impl.py:1
│                    │     │        │              └─ 1.701 tensorflow/__init__.py:1
│                    │     │        │                    [116 frames hidden]  tensorflow, , inspect, requ...
│                    │     │        └─ 0.063 numpy/__init__.py:1
│                    │     └─ 0.091 keras/src/backend/tensorflow/__init__.py:1
│                    │        └─ 0.089 keras/src/backend/tensorflow/numpy.py:1
│                    │           └─ 0.089 elementwise_unary  keras/src/backend/tensorflow/sparse.py:348
│                    │              └─ 0.089 update_wrapper  functools.py:35
│                    └─ 0.186 keras/src/saving/__init__.py:1
│                       └─ 0.186 keras/src/saving/saving_api.py:1
│                          └─ 0.186 keras/src/legacy/saving/legacy_h5_format.py:1
│                             └─ 0.182 keras/src/legacy/saving/saving_utils.py:1
│                                ├─ 0.139 keras/src/models/__init__.py:1
│                                │  └─ 0.138 keras/src/models/functional.py:1
│                                │     └─ 0.138 keras/src/models/model.py:1
│                                │        └─ 0.135 keras/src/trainers/trainer.py:1
│                                │           └─ 0.135 keras/src/trainers/data_adapters/__init__.py:1
│                                │              └─ 0.134 keras/src/trainers/data_adapters/array_data_adapter.py:1
│                                │                 └─ 0.133 keras/src/trainers/data_adapters/array_slicing.py:1
│                                │                    └─ 0.133 pandas/__init__.py:1
│                                │                          [5 frames hidden]  pandas
│                                └─ 0.043 keras/src/layers/__init__.py:1
├─ 0.940 main  check_loading.py:20
│  ├─ 0.540 deserialize_keras_object  keras/src/saving/serialization_lib.py:393
│  │  └─ 0.539 GPT2CausalLM.from_config  keras_nlp/src/models/task.py:143
│  │     └─ 0.528 deserialize  keras/src/layers/__init__.py:153
│  │        └─ 0.528 deserialize_keras_object  keras/src/saving/serialization_lib.py:393
│  │           └─ 0.522 GPT2Backbone.from_config  keras_nlp/src/models/backbone.py:139
│  │                 [3 frames hidden]  keras_nlp
│  │                    0.522 GPT2Backbone.__init__  keras_nlp/src/models/gpt2/gpt2_backbone.py:92
│  │                    ├─ 0.294 TransformerDecoder.__call__  keras_nlp/src/layers/modeling/transformer_decoder.py:253
│  │                    │  └─ 0.289 build_wrapper  keras/src/layers/layer.py:220
│  │                    │     └─ 0.289 TransformerDecoder.build  keras_nlp/src/layers/modeling/transformer_decoder.py:134
│  │                    │        └─ 0.223 build_wrapper  keras/src/layers/layer.py:220
│  │                    │           ├─ 0.138 CachedMultiHeadAttention.build  keras/src/layers/attention/multi_head_attention.py:199
│  │                    │           │  └─ 0.091 build_wrapper  keras/src/layers/layer.py:220
│  │                    │           │     └─ 0.088 EinsumDense.build  keras/src/layers/core/einsum_dense.py:147
│  │                    │           │        └─ 0.082 EinsumDense.add_weight  keras/src/layers/layer.py:455
│  │                    │           │           └─ 0.080 Variable.__init__  keras/src/backend/common/variables.py:80
│  │                    │           │              └─ 0.046 Variable._initialize  keras/src/backend/tensorflow/core.py:30
│  │                    │           │                 └─ 0.045 error_handler  tensorflow/python/util/traceback_utils.py:138
│  │                    │           │                       [14 frames hidden]  tensorflow,
│  │                    │           ├─ 0.046 Dense.build  keras/src/layers/core/dense.py:102
│  │                    │           │  └─ 0.045 Dense.add_weight  keras/src/layers/layer.py:455
│  │                    │           │     └─ 0.045 Variable.__init__  keras/src/backend/common/variables.py:80
│  │                    │           └─ 0.033 LayerNormalization.build  keras/src/layers/normalization/layer_normalization.py:147
│  │                    │              └─ 0.033 LayerNormalization.add_weight  keras/src/layers/layer.py:455
│  │                    │                 └─ 0.033 Variable.__init__  keras/src/backend/common/variables.py:80
│  │                    └─ 0.199 Dropout.__init__  keras/src/layers/regularization/dropout.py:41
│  │                       └─ 0.199 SeedGenerator.__init__  keras/src/random/seed_generator.py:48
│  │                          └─ 0.199 Variable.__init__  keras/src/backend/common/variables.py:80
│  │                             └─ 0.198 seed_initializer  keras/src/random/seed_generator.py:70
│  │                                └─ 0.198 convert_to_tensor  keras/src/backend/tensorflow/core.py:102
│  │                                   └─ 0.198 error_handler  tensorflow/python/util/traceback_utils.py:138
│  │                                         [17 frames hidden]  tensorflow,
│  └─ 0.399 error_handler  keras/src/utils/traceback_utils.py:110
│     └─ 0.399 GPT2CausalLM.load_weights  keras/src/models/model.py:321
│        └─ 0.399 load_weights  keras/src/saving/saving_api.py:226
│           └─ 0.399 load_weights_only  keras/src/saving/saving_lib.py:239
│              └─ 0.399 _load_state  keras/src/saving/saving_lib.py:395
│                 └─ 0.392 _load_container_state  keras/src/saving/saving_lib.py:510
│                    └─ 0.392 _load_state  keras/src/saving/saving_lib.py:395
│                       └─ 0.391 _load_container_state  keras/src/saving/saving_lib.py:510
│                          └─ 0.390 _load_state  keras/src/saving/saving_lib.py:395
│                             ├─ 0.209 _load_state  keras/src/saving/saving_lib.py:395
│                             │  ├─ 0.088 Dense.load_own_variables  keras/src/layers/core/dense.py:224
│                             │  │  └─ 0.086 Variable.assign  keras/src/backend/common/variables.py:223
│                             │  │     └─ 0.086 Variable._convert_to_tensor  keras/src/backend/tensorflow/core.py:53
│                             │  │        └─ 0.086 convert_to_tensor  keras/src/backend/tensorflow/core.py:102
│                             │  │           └─ 0.085 error_handler  tensorflow/python/util/traceback_utils.py:138
│                             │  │                 [14 frames hidden]  tensorflow, h5py
│                             │  └─ 0.077 _load_state  keras/src/saving/saving_lib.py:395
│                             │     └─ 0.060 EinsumDense.load_own_variables  keras/src/layers/core/einsum_dense.py:279
│                             │        └─ 0.059 Variable.assign  keras/src/backend/common/variables.py:223
│                             │           └─ 0.046 Variable._convert_to_tensor  keras/src/backend/tensorflow/core.py:53
│                             │              └─ 0.046 convert_to_tensor  keras/src/backend/tensorflow/core.py:102
│                             │                 └─ 0.046 error_handler  tensorflow/python/util/traceback_utils.py:138
│                             │                       [11 frames hidden]  tensorflow
│                             └─ 0.172 ReversibleEmbedding.load_own_variables  keras_nlp/src/layers/modeling/reversible_embedding.py:151
│                                └─ 0.172 ReversibleEmbedding.load_own_variables  keras/src/layers/core/embedding.py:214
│                                   └─ 0.164 Variable.assign  keras/src/backend/common/variables.py:223
│                                      └─ 0.164 Variable._convert_to_tensor  keras/src/backend/tensorflow/core.py:53
│                                         └─ 0.164 convert_to_tensor  keras/src/backend/tensorflow/core.py:102
│                                            └─ 0.164 error_handler  tensorflow/python/util/traceback_utils.py:138
│                                                  [14 frames hidden]  tensorflow, h5py
└─ 0.102 keras_nlp/__init__.py:1
      [18 frames hidden]  keras_nlp, tensorflow_text
```

Grvzard commented 5 months ago

Digging into the example provided by @james77777778: in the hidden frames there is a call chain Group.__getitem__ -> ZipExtFile.seek, which makes sense given that we are reading from a zip archive.

In the Python stdlib's zipfile.ZipExtFile, the chain is seek -> read -> _read1 -> _update_crc. The overhead caused by _update_crc during each seek() call is significant. Reference: https://github.com/python/cpython/blob/f878d46e5614f08a9302fcb6fc611ef49e9acf2f/Lib/zipfile/__init__.py#L1133
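The effect can be reproduced with the standard library alone. The sketch below is an assumption-laden stand-in, not the Keras code path: it uses random bytes instead of a real weights file, a deflated member, and a made-up access pattern of 200 scattered 16-byte reads to mimic many small h5py lookups. It times the same reads through a `ZipExtFile` and through a `BytesIO` copy; absolute numbers will depend on the Python version and the compression method.

```python
import io
import os
import time
import zipfile

PAYLOAD = os.urandom(1 << 20)  # 1 MiB of random bytes; stands in for the weights
# Scattered offsets so that consecutive seeks jump both forward and backward.
OFFSETS = [(i * 611953) % (len(PAYLOAD) - 16) for i in range(200)]


def build_archive():
    """Return an in-memory zip holding PAYLOAD as one deflated member."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as archive:
        archive.writestr("model.weights.h5", PAYLOAD)
    buffer.seek(0)
    return buffer


def random_reads(file_obj):
    """Seek to each offset and read 16 bytes from it."""
    chunks = []
    for offset in OFFSETS:
        file_obj.seek(offset)
        chunks.append(file_obj.read(16))
    return chunks


with zipfile.ZipFile(build_archive()) as archive:
    # Reads through the zip member: every seek goes through ZipExtFile.seek.
    with archive.open("model.weights.h5") as member:
        start = time.perf_counter()
        from_zip = random_reads(member)
        zip_seconds = time.perf_counter() - start

    # Same reads against a fully decompressed in-memory copy.
    in_memory = io.BytesIO(archive.read("model.weights.h5"))
    start = time.perf_counter()
    from_memory = random_reads(in_memory)
    memory_seconds = time.perf_counter() - start

print(f"ZipExtFile: {zip_seconds:.4f}s, BytesIO: {memory_seconds:.4f}s")
```

Both paths return identical bytes; only the time per seek differs.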

Grvzard commented 5 months ago

A simple way to deal with it, which will work fine:

https://github.com/keras-team/keras/blob/a2df0f9ac595639aa2d3a0359122b030d934389e/keras/src/saving/saving_lib.py#L620-L627

by changing line 624 to self.io_file = io.BytesIO(self.archive.open(self.root_path, "r").read())
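Outside of Keras, the same idea looks roughly like the sketch below. `open_member_in_memory` is a hypothetical helper, and the archive and its `model.weights.h5` member are stand-ins built on the spot; the point is only that the returned `io.BytesIO` seeks without touching the zip layer, at the cost of holding the whole uncompressed member in RAM.

```python
import io
import zipfile


def open_member_in_memory(archive, member_name):
    """Read one zip member fully and return it as a seekable BytesIO."""
    with archive.open(member_name, "r") as member:
        return io.BytesIO(member.read())


# Stand-in archive; "model.weights.h5" is an assumed member name.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as archive:
    archive.writestr("model.weights.h5", bytes(range(256)) * 64)

with zipfile.ZipFile(buffer) as archive:
    weights_file = open_member_in_memory(archive, "model.weights.h5")

# The copy stays usable after the archive is closed, and seeks are cheap.
weights_file.seek(300)
print(weights_file.read(4))
print(len(weights_file.getvalue()), "bytes held in memory")
```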

sirfz commented 5 months ago

That probably fixes the speed issue but would lead to unwanted extra memory usage which is undesirable

fchollet commented 5 months ago

> That probably fixes the speed issue but would lead to unwanted extra memory usage which is undesirable

Is that a good tradeoff? Should we instead unzip on disk then load from the h5 file? What do you think @james77777778 @Grvzard ?

james77777778 commented 5 months ago

> Is that a good tradeoff?

Generally, it should be okay to load the entire h5 file into memory before loading. This is already what happens when saving:

  1. write into memory first: https://github.com/keras-team/keras/blob/77ef1792201cd30b52146c71dd0380786512ac84/keras/src/saving/saving_lib.py#L620-L622
  2. write to disk when closing: https://github.com/keras-team/keras/blob/77ef1792201cd30b52146c71dd0380786512ac84/keras/src/saving/saving_lib.py#L635-L638

We can also provide an option to let users decide whether to use a faster but more memory-intensive approach.

> Should we instead unzip on disk then load from the h5 file?

Actually, h5py doesn't recommend using file-like objects: https://docs.h5py.org/en/stable/high/file.html#python-file-like-objects So unzipping and then loading from the H5 file might be a better approach, IMO.
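A rough sketch of the unzip-then-load approach, using only the standard library. `extract_member` is a hypothetical helper and the archive is a stand-in; a real loader would presumably open the extracted path with h5py, which is left as a comment here so the example has no third-party dependencies.

```python
import io
import os
import tempfile
import zipfile


def extract_member(archive_file, member_name, target_dir):
    """Extract a single member into target_dir and return its path on disk."""
    with zipfile.ZipFile(archive_file) as archive:
        return archive.extract(member_name, path=target_dir)


# Stand-in archive; "model.weights.h5" is an assumed member name.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("model.weights.h5", b"weights" * 1000)

with tempfile.TemporaryDirectory() as tmp_dir:
    weights_path = extract_member(buffer, "model.weights.h5", tmp_dir)
    # A real loader would open weights_path with h5py here (assumption);
    # a plain open() shows that seeks now hit an ordinary file.
    with open(weights_path, "rb") as weights_file:
        weights_file.seek(-7, os.SEEK_END)
        tail = weights_file.read()
    extracted_size = os.path.getsize(weights_path)

print(extracted_size, tail)
```

The temporary directory is removed when the `with` block exits, so this trades extra disk I/O and temporary disk space for lower peak memory than the in-memory copy.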

Grvzard commented 5 months ago

> So, unzipping and then loading from the H5 file might be a better approach

Same.
