keras-team / keras

TorchModuleWrapper default `get_config` creates non serializable objects #19226

Open gilfree opened 6 months ago

gilfree commented 6 months ago

Hi, and thanks to all keras-team :)

The get_config of TorchModuleWrapper uses BytesIO and torch.save to create a byte array:

https://github.com/keras-team/keras/blob/1137074a9a1c237473f2fe57ab277c697892c6f1/keras/utils/torch_utils.py#L140

In the serialization lib, bytes objects are then decoded as UTF-8: https://github.com/keras-team/keras/blob/1137074a9a1c237473f2fe57ab277c697892c6f1/keras/saving/serialization_lib.py#L154

An arbitrary byte array, such as the one created by torch.save, is not guaranteed to be valid UTF-8. When it is not, decoding fails and saving raises an exception.
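To make the failure mode concrete without needing torch, here is a minimal sketch. The helper name `is_valid_utf8` and both payloads are made up for illustration; the binary payload is only a stand-in for a serialized buffer, not real torch.save output.

```python
def is_valid_utf8(payload: bytes) -> bool:
    """Return True if payload can be decoded as UTF-8 text."""
    try:
        payload.decode("utf-8")
    except UnicodeDecodeError:
        return False
    return True


# Text-like bytes decode without problems.
text_payload = b'{"channels": 10}'

# Stand-in binary payload (assumption, not real torch.save output):
# ZIP magic bytes followed by 0x80, which is not a valid UTF-8 start byte.
binary_payload = b"PK\x03\x04" + b"\x80\x02"

print(is_valid_utf8(text_payload))
print(is_valid_utf8(binary_payload))
```

Any binary format that can contain bytes like 0x80 in arbitrary positions will hit the same error, so the problem is not specific to one particular model.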

I think the real issue here is the decoding of arbitrary bytes to a string, but I assume this will be hard to fix because changing it would create format compatibility issues; the behavior seems to date back a few years, to Keras 2.

Maybe the torch model buffer could, by default, be dumped to a JSON-compatible string and read back from it, or dumped to a str and read back with ast.literal_eval, or go through some other safe binary<->string conversion.
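As a rough sketch of what such a conversion could look like, the snippet below round-trips arbitrary bytes through a plain string in two ways: base64, and repr plus ast.literal_eval. All four function names are hypothetical and none of this is existing Keras code.

```python
import ast
import base64
import json


def to_str_base64(payload: bytes) -> str:
    """Encode arbitrary bytes as an ASCII-only base64 string."""
    return base64.b64encode(payload).decode("ascii")


def from_str_base64(text: str) -> bytes:
    """Inverse of to_str_base64."""
    return base64.b64decode(text)


def to_str_literal(payload: bytes) -> str:
    """Encode bytes as a Python bytes literal, e.g. "b'\\x80'"."""
    return repr(payload)


def from_str_literal(text: str) -> bytes:
    """Inverse of to_str_literal; literal_eval only evaluates literals."""
    value = ast.literal_eval(text)
    if not isinstance(value, bytes):
        raise ValueError("expected a bytes literal")
    return value


# Every possible byte value, including ones that are invalid as UTF-8.
payload = bytes(range(256))

for encode, decode in [
    (to_str_base64, from_str_base64),
    (to_str_literal, from_str_literal),
]:
    text = encode(payload)
    json.dumps({"module": text})  # the string is JSON-serializable
    assert decode(text) == payload  # and the round trip is lossless
```

Both variants survive JSON. Base64 has a fixed size overhead of roughly one third, while the literal form can grow up to four characters per byte for non-printable data, so base64 is probably the better fit for a model buffer.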

SuryanarayanaY commented 6 months ago

Hi @gilfree , Thanks for reporting. It would be helpful if you could submit a minimal reproducible snippet to check the behaviour. Thanks!

gilfree commented 6 months ago

To reproduce, add the following to the file: keras/utils/torch_utils_test.py:

    def test_save_load(self):
        import keras

        class M(keras.Model):
            def __init__(self, channels=10, **kwargs):
                super().__init__()
                # Assigning a torch module to an attribute auto-wraps it
                # in TorchModuleWrapper.
                self.sequence = torch.nn.Sequential(
                    torch.nn.Conv2d(1, channels, kernel_size=(3, 3)),
                )

            def call(self, x):
                return self.sequence(x)

            def get_config(self):
                # Return the wrapper's config, which contains raw bytes.
                return self.sequence.get_config()

        m = M()
        x = torch.ones((10, 1, 28, 28))
        m(x)
        m.save("model.keras")  # fails with UnicodeDecodeError

Run with:

CUDA_VISIBLE_DEVICES=-1 KERAS_BACKEND=torch pytest -vvv -vs keras/utils/torch_utils_test.py -k save_load 

This is an artificial example, since one would probably not write get_config this way; it is just the shortest way I have found to demonstrate the issue.

SuryanarayanaY commented 4 months ago

Hi @gilfree ,

I have added the code snippet you provided to the file: keras/utils/torch_utils_test.py and replicated the reported error. Attached logs below.

(keras-jax) suryanarayanay-macbookpro:keras suryanarayanay$ CUDA_VISIBLE_DEVICES=-1 KERAS_BACKEND=torch pytest -vvv -vs keras/utils/torch_utils_test.py -k save_load 

===================================================================== test session starts ======================================================================
platform darwin -- Python 3.10.13, pytest-7.4.2, pluggy-1.3.0 -- /Users/suryanarayanay/miniconda/envs/keras-jax/bin/python
cachedir: .pytest_cache
rootdir: /Users/suryanarayanay/keraswork
configfile: pyproject.toml
plugins: cov-4.1.0, anyio-4.1.0
collected 0 items                                                                                                                                              

==================================================================== no tests ran in 0.00s =====================================================================
ERROR: file or directory not found: keras/utils/torch_utils_test.py

(keras-jax) suryanarayanay-macbookpro:keras suryanarayanay$ 
(keras-jax) suryanarayanay-macbookpro:keras suryanarayanay$ CUDA_VISIBLE_DEVICES=-1 KERAS_BACKEND=torch pytest -vvv -vs /Users/suryanarayanay/keraswork/keras/utils/torch_utils_test.py -k save_load 
===================================================================== test session starts ======================================================================
platform darwin -- Python 3.10.13, pytest-7.4.2, pluggy-1.3.0 -- /Users/suryanarayanay/miniconda/envs/keras-jax/bin/python
cachedir: .pytest_cache
rootdir: /Users/suryanarayanay/keraswork
configfile: pyproject.toml
plugins: cov-4.1.0, anyio-4.1.0
collected 15 items / 14 deselected / 1 selected                                                                                                                

utils/torch_utils_test.py::TorchUtilsTest::test_save_load FAILED

=========================================================================== FAILURES ===========================================================================
________________________________________________________________ TorchUtilsTest.test_save_load _________________________________________________________________

self = <keras.utils.torch_utils_test.TorchUtilsTest testMethod=test_save_load>

    def test_save_load(self):
        import keras
        class M(keras.Model):
            def __init__(self,channels=10, **kwargs):
                super().__init__()
                self.sequence = torch.nn.Sequential(
                    torch.nn.Conv2d(1, channels, kernel_size=(3, 3)),
                )
            def call(self, x):
                return self.sequence(x)

            def get_config(self):
                return self.sequence.get_config()
        m = M()
        x=torch.ones((10,1,28, 28))
        m(x)
>       m.save('model.keras')

utils/torch_utils_test.py:211: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
utils/traceback_utils.py:113: in error_handler
    return fn(*args, **kwargs)
models/model.py:302: in save
    return saving_api.save_model(self, filepath, overwrite, **kwargs)
saving/saving_api.py:100: in save_model
    saving_lib.save_model(model, filepath)
saving/saving_lib.py:92: in save_model
    _save_model_to_fileobj(model, f, weights_format)
saving/saving_lib.py:97: in _save_model_to_fileobj
    serialized_model_dict = serialize_keras_object(model)
saving/serialization_lib.py:239: in serialize_keras_object
    inner_config = _get_class_or_fn_config(obj)
saving/serialization_lib.py:373: in _get_class_or_fn_config
    return serialize_dict(config)
saving/serialization_lib.py:385: in serialize_dict
    return {key: serialize_keras_object(value) for key, value in obj.items()}
saving/serialization_lib.py:385: in <dictcomp>
    return {key: serialize_keras_object(value) for key, value in obj.items()}
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

obj = b'PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x12\x00ar...f6\n\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00PK\x05\x06\x00\x00\x00\x00\x06\x00\x06\x00~\x01\x00\x00x\t\x00\x00\x00\x00'

    @keras_export(
        [
            "keras.saving.serialize_keras_object",
            "keras.utils.serialize_keras_object",
        ]
    )
    def serialize_keras_object(obj):
        """Retrieve the config dict by serializing the Keras object.

        `serialize_keras_object()` serializes a Keras object to a python dictionary
        that represents the object, and is a reciprocal function of
        `deserialize_keras_object()`. See `deserialize_keras_object()` for more
        information about the config format.

        Args:
            obj: the Keras object to serialize.

        Returns:
            A python dict that represents the object. The python dict can be
            deserialized via `deserialize_keras_object()`.
        """
        if obj is None:
            return obj

        if isinstance(obj, PLAIN_TYPES):
            return obj

        if isinstance(obj, (list, tuple)):
            config_arr = [serialize_keras_object(x) for x in obj]
            return tuple(config_arr) if isinstance(obj, tuple) else config_arr
        if isinstance(obj, dict):
            return serialize_dict(obj)

        # Special cases:
        if isinstance(obj, bytes):
            return {
                "class_name": "__bytes__",
>               "config": {"value": obj.decode("utf-8")},
            }
E           UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte

saving/serialization_lib.py:154: UnicodeDecodeError
=================================================================== short test summary info ====================================================================
FAILED utils/torch_utils_test.py::TorchUtilsTest::test_save_load - UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
=============================================================== 1 failed, 14 deselected in 3.41s ===============================================================
(keras-jax) suryanarayanay-macbookpro:keras suryanarayanay$

SuryanarayanaY commented 4 months ago

Hi @gilfree ,

This might be because Torch objects lack get_config methods for serialization. This seems to be addressed in this comment.

Please take a look and let us know.

github-actions[bot] commented 4 months ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

gilfree commented 4 months ago

It is correct that torch objects lack get_config. The comment above is the cause of the bug, not the solution to it.

As stated there, Torch modules are automatically wrapped by TorchModuleWrapper. The issue is that the wrapping is buggy. The wrapper falls back to serializing the module, which might actually be a good idea, but the serialization code contains a bug.

This makes it impossible to use the auto-wrapping of torch modules within a Keras model together with save/load, since the user cannot implement a correct get_config in the auto-wrapping case.

That makes the torch autowrapping pretty useless as I see it - you can't resume from a checkpoint of a model for example.
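For what it's worth, here is a sketch of how the bytes branch could be made tolerant of binary data while still producing and reading the current format for text-like bytes. The `{"class_name": "__bytes__", "config": {"value": ...}}` shape is taken from the traceback above; the `"encoding"` key, both function names, and the assumption that the reader re-encodes `value` as UTF-8 are mine, not existing Keras behavior.

```python
import base64


def serialize_bytes(obj: bytes) -> dict:
    """Serialize bytes, falling back to base64 when UTF-8 decoding fails."""
    try:
        # Same shape as the existing format for bytes that are valid UTF-8.
        return {
            "class_name": "__bytes__",
            "config": {"value": obj.decode("utf-8")},
        }
    except UnicodeDecodeError:
        # Hypothetical extension: mark the value as base64-encoded.
        return {
            "class_name": "__bytes__",
            "config": {
                "value": base64.b64encode(obj).decode("ascii"),
                "encoding": "base64",
            },
        }


def deserialize_bytes(config: dict) -> bytes:
    """Read both the old (UTF-8) and the hypothetical base64 variant."""
    inner = config["config"]
    if inner.get("encoding") == "base64":
        return base64.b64decode(inner["value"])
    return inner["value"].encode("utf-8")


# Stand-in for a binary model buffer (not real torch.save output).
binary = b"PK\x03\x04\x80\x02"
assert deserialize_bytes(serialize_bytes(binary)) == binary

# Configs written in the old format are still readable.
old_config = {"class_name": "__bytes__", "config": {"value": "abc"}}
assert deserialize_bytes(old_config) == b"abc"
```

With something along these lines, files written before the change would still load, and only payloads that currently crash on save would take the new path.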