ZPZhou-lab / tfkan

The tensorflow implementation of KANs
MIT License

Loading a KAN model doesn't work (Fixed), and saving as a Tensorflow Lite model doesn't work #11

Closed ilmari99 closed 2 months ago

ilmari99 commented 2 months ago

Hello,

Thanks very much for this project, it's great! However, I ran into two different problems with this script:

import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

def convert_model_to_tflite(file_path : str, output_file : str):
    print("Converting '{}' to '{}'".format(file_path, output_file))

    model = tf.keras.models.load_model(file_path, compile=True)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops.
    ]

    tflite_model = converter.convert()
    with open(output_file, "wb") as f:
        f.write(tflite_model)
    return output_file

if __name__ == "__main__":
    X = np.random.rand(100, 10)
    y = np.random.rand(100)

    X = tf.convert_to_tensor(X)
    y = tf.convert_to_tensor(y)

    kan_model = tf.keras.Sequential([
        DenseKAN(8),
        DenseKAN(8),
        DenseKAN(1),
    ])

    kan_model.compile(optimizer="adam", loss="mse")
    kan_model.fit(X, y, epochs=5)

    # Save the model
    kan_model.save("kan_model.keras")
    print(f"Saved model to 'kan_model.keras'")

    # Load the model
    kan_model = tf.keras.models.load_model("kan_model.keras")
    print(f"Loaded model from 'kan_model.keras'")

    # Convert the model to tflite
    convert_model_to_tflite("kan_model.keras", "kan_model.tflite")

First, loading a KAN model this way doesn't work for me out of the box. The script fails on the line kan_model = tf.keras.models.load_model("kan_model.keras") with the following error:

Traceback (most recent call last):
  File "/home/ilmari/python/RLFramework/BlokusPentobi/test_kan_minimal.py", line 42, in <module>
    kan_model = tf.keras.models.load_model("kan_model.keras")
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/saving_api.py", line 254, in load_model
    return saving_lib.load_model(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 281, in load_model
    raise e
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 246, in load_model
    model = deserialize_keras_object(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 728, in deserialize_keras_object
    instance = cls.from_config(inner_config)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/engine/sequential.py", line 466, in from_config
    layer = layer_module.deserialize(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/layers/serialization.py", line 276, in deserialize
    return serialization_lib.deserialize_keras_object(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 731, in deserialize_keras_object
    instance.build_from_config(build_config)
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 2331, in build_from_config
    self.build(input_shape)
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tfkan/layers/dense.py", line 83, in build
    raise ValueError(f"expected basis_activation to be str or callable, found {type(self.basis_activation)}")
ValueError: expected basis_activation to be str or callable, found <class 'tensorflow.python.trackable.data_structures._DictWrapper'>

I fixed this by replacing the line "basis_activation": self.basis_activation with "basis_activation": tf.keras.activations.serialize(self.basis_activation) in the get_config method of DenseKAN.
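For anyone hitting the same error, here is a minimal sketch of the idea behind that fix. TinyActivationLayer is a hypothetical stand-in I made up for illustration, not the actual DenseKAN code; the only part taken from the fix above is the tf.keras.activations.serialize call in get_config. It assumes the activation is resolved with tf.keras.activations.get in the constructor.

```python
import tensorflow as tf


class TinyActivationLayer(tf.keras.layers.Layer):
    """Hypothetical stand-in for DenseKAN, reduced to the activation handling."""

    def __init__(self, basis_activation="tanh", **kwargs):
        super().__init__(**kwargs)
        # activations.get accepts a name, a callable, or a serialized config,
        # so the value written by get_config() below can be passed straight back in.
        self.basis_activation = tf.keras.activations.get(basis_activation)

    def call(self, inputs):
        return self.basis_activation(inputs)

    def get_config(self):
        config = super().get_config()
        # Store a serializable description of the activation instead of the
        # function object itself.
        config["basis_activation"] = tf.keras.activations.serialize(
            self.basis_activation
        )
        return config


layer = TinyActivationLayer("tanh")
restored = TinyActivationLayer.from_config(layer.get_config())
```

After the round trip, restored.basis_activation is a callable again, so a check like the one in build (str or callable) should pass.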

Once the above error is fixed, I hit a second problem: converting the model to TFLite. I haven't been able to fix this one, and the conversion raises the following error:

Traceback (most recent call last):
  File "/home/ilmari/python/RLFramework/BlokusPentobi/test_kan_minimal.py", line 47, in <module>
    convert_model_to_tflite("kan_model.keras", "kan_model.tflite")
  File "/home/ilmari/python/RLFramework/BlokusPentobi/test_kan_minimal.py", line 16, in convert_model_to_tflite
    tflite_model = converter.convert()
                   ^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1139, in wrapper
    return self._convert_and_export_metrics(convert_func, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1093, in _convert_and_export_metrics
    result = convert_func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1606, in convert
    self._freeze_keras_model()
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 215, in wrapper
    raise error from None  # Re-throws the exception.
    ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 205, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1553, in _freeze_keras_model
    concrete_func = func.get_concrete_function()
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1227, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1197, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 695, in _initialize
    self._concrete_variable_creation_fn = tracing_compilation.trace_function(
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    concrete_function = _maybe_define_function(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    concrete_function = _create_concrete_function(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    traced_func_graph = func_graph_module.func_graph_from_py_func(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 598, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/tflite_keras_util.py", line 190, in _wrapped_model
    outputs = model(inputs, training=False)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tfkan/layers/dense.py", line 102, in call
    inputs, orig_shape = self._check_and_reshape_inputs(inputs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tfkan/layers/dense.py", line 127, in _check_and_reshape_inputs
    ndim = len(shape)
           ^^^^^^^^^^
TypeError: Exception encountered when calling layer 'sequential' (type Sequential).

len is not well defined for a symbolic Tensor (Shape:0). Please call `x.shape` rather than `len(x)` for shape information.

Call arguments received by layer 'sequential' (type Sequential):
  • inputs=tf.Tensor(shape=(None, 10), dtype=float64)
  • training=False
  • mask=None

Is there something I could do about this error?

ilmari99 commented 2 months ago

I figured out why it didn't work: TF Lite calculates shapes using SymbolicTensors, which don't work with len(). I opened a pull request to fix/circumvent this: https://github.com/ZPZhou-lab/tfkan/pull/12
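For reference, a minimal sketch of one way to avoid len() on a symbolic shape tensor. This is not necessarily what the pull request does; flatten_leading_dims and traced are hypothetical names used only for illustration. The idea is to read the rank from the static shape (inputs.shape.rank), which is known while tracing, and to use tf.shape only for the dimensions that are dynamic.

```python
import tensorflow as tf


def flatten_leading_dims(inputs, in_size):
    """Reshape (..., in_size) to (-1, in_size); also return the leading dims."""
    # The static rank is a plain Python int (or None), so it is safe to use
    # during tracing, unlike len(tf.shape(inputs)).
    ndim = inputs.shape.rank
    if ndim is None or ndim < 2:
        raise ValueError(f"expected inputs with at least 2 dims, found rank {ndim}")
    if inputs.shape[-1] != in_size:
        raise ValueError(
            f"expected last dim to be {in_size}, found {inputs.shape[-1]}"
        )
    # The leading dims may be unknown at trace time, so read them dynamically.
    orig_shape = tf.shape(inputs)[:-1]
    return tf.reshape(inputs, (-1, in_size)), orig_shape


# Tracing with unknown leading dims mimics what happens during conversion.
@tf.function(input_signature=[tf.TensorSpec([None, None, 10], tf.float32)])
def traced(x):
    return flatten_leading_dims(x, 10)


concrete = traced.get_concrete_function()
```

Getting the concrete function here exercises the same kind of symbolic tracing that failed in the traceback above, but without calling len() on a tensor.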