Thanks very much for this project, it's great! However, I ran into two different problems with this script:
import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN
def convert_model_to_tflite(file_path: str, output_file: str):
    """Load a saved Keras model from disk and write it out as a TFLite flatbuffer.

    Args:
        file_path: Path of the saved Keras model to load (e.g. a ``.keras`` file).
        output_file: Destination path; the converted model bytes are written here.

    Returns:
        ``output_file``, unchanged.
    """
    print(f"Converting '{file_path}' to '{output_file}'")
    keras_model = tf.keras.models.load_model(file_path, compile=True)

    tflite_converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    # Allow both TFLite builtin ops and full TensorFlow ("select") ops, so ops
    # without a TFLite builtin kernel can still be included in the output.
    tflite_converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    flatbuffer = tflite_converter.convert()

    # convert() returns raw bytes, so write in binary mode.
    with open(output_file, "wb") as out_handle:
        out_handle.write(flatbuffer)
    return output_file
if __name__ == "__main__":
    # Minimal reproduction: train a tiny KAN, save it, reload it, then
    # convert the saved file to TFLite.
    model_path = "kan_model.keras"
    tflite_path = "kan_model.tflite"

    # np.random.rand returns float64 arrays. Cast to float32 so the model is
    # first called (and therefore saved) with a float32 input signature rather
    # than float64; TFLite conversion generally targets float32 models.
    X = tf.convert_to_tensor(np.random.rand(100, 10).astype(np.float32))
    y = tf.convert_to_tensor(np.random.rand(100).astype(np.float32))

    kan_model = tf.keras.Sequential([
        DenseKAN(8),
        DenseKAN(8),
        DenseKAN(1),
    ])
    kan_model.compile(optimizer="adam", loss="mse")
    kan_model.fit(X, y, epochs=5)

    # Save the model.
    kan_model.save(model_path)
    print(f"Saved model to '{model_path}'")

    # Reload it to check that the saved file round-trips through load_model.
    # The conversion below loads it again from disk on its own.
    kan_model = tf.keras.models.load_model(model_path)
    print(f"Loaded model from '{model_path}'")

    # Convert the saved model to TFLite.
    convert_model_to_tflite(model_path, tflite_path)
First, loading a KAN model this way doesn't work for me out of the box.
The script fails on the line `kan_model = tf.keras.models.load_model("kan_model.keras")` with the following error:
Traceback (most recent call last):
File "/home/ilmari/python/RLFramework/BlokusPentobi/test_kan_minimal.py", line 42, in <module>
kan_model = tf.keras.models.load_model("kan_model.keras")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/saving_api.py", line 254, in load_model
return saving_lib.load_model(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 281, in load_model
raise e
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 246, in load_model
model = deserialize_keras_object(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 728, in deserialize_keras_object
instance = cls.from_config(inner_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/engine/sequential.py", line 466, in from_config
layer = layer_module.deserialize(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/layers/serialization.py", line 276, in deserialize
return serialization_lib.deserialize_keras_object(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/saving/serialization_lib.py", line 731, in deserialize_keras_object
instance.build_from_config(build_config)
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 2331, in build_from_config
self.build(input_shape)
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tfkan/layers/dense.py", line 83, in build
raise ValueError(f"expected basis_activation to be str or callable, found {type(self.basis_activation)}")
ValueError: expected basis_activation to be str or callable, found <class 'tensorflow.python.trackable.data_structures._DictWrapper'>
I fixed this by replacing the line
`"basis_activation": self.basis_activation` with
`"basis_activation": tf.keras.activations.serialize(self.basis_activation)` in the `get_config` method of `DenseKAN`.
Once the above error is fixed, I run into a second problem: converting the model to TFLite. I haven't been able to fix this one; trying to convert the model to TFLite raises the following error:
Traceback (most recent call last):
File "/home/ilmari/python/RLFramework/BlokusPentobi/test_kan_minimal.py", line 47, in <module>
convert_model_to_tflite("kan_model.keras", "kan_model.tflite")
File "/home/ilmari/python/RLFramework/BlokusPentobi/test_kan_minimal.py", line 16, in convert_model_to_tflite
tflite_model = converter.convert()
^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1139, in wrapper
return self._convert_and_export_metrics(convert_func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1093, in _convert_and_export_metrics
result = convert_func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1606, in convert
self._freeze_keras_model()
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 215, in wrapper
raise error from None # Re-throws the exception.
^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 205, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1553, in _freeze_keras_model
concrete_func = func.get_concrete_function()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1227, in get_concrete_function
concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1197, in _get_concrete_function_garbage_collected
self._initialize(args, kwargs, add_initializers_to=initializers)
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 695, in _initialize
self._concrete_variable_creation_fn = tracing_compilation.trace_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
concrete_function = _maybe_define_function(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
concrete_function = _create_concrete_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
traced_func_graph = func_graph_module.func_graph_from_py_func(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 598, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tensorflow/lite/python/tflite_keras_util.py", line 190, in _wrapped_model
outputs = model(inputs, training=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tfkan/layers/dense.py", line 102, in call
inputs, orig_shape = self._check_and_reshape_inputs(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ilmari/python/RLFramework/.venv/lib/python3.11/site-packages/tfkan/layers/dense.py", line 127, in _check_and_reshape_inputs
ndim = len(shape)
^^^^^^^^^^
TypeError: Exception encountered when calling layer 'sequential' (type Sequential).
len is not well defined for a symbolic Tensor (Shape:0). Please call `x.shape` rather than `len(x)` for shape information.
Call arguments received by layer 'sequential' (type Sequential):
• inputs=tf.Tensor(shape=(None, 10), dtype=float64)
• training=False
• mask=None
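I figured out why it didn't work: during TFLite conversion the model is traced with symbolic tensors, and `len()` cannot be called on a symbolic tensor (here the `Shape:0` tensor passed to `len(shape)` in `_check_and_reshape_inputs`). I opened a pull request to fix/work around this: https://github.com/ZPZhou-lab/tfkan/pull/12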
Hello,
Thanks very much for this project, it's great! However, I ran into two different problems with this script:
First, loading a KAN model this way doesn't work for me out of the box. The script fails on the line
`kan_model = tf.keras.models.load_model("kan_model.keras")`
with an error. I fixed this by replacing the line
`"basis_activation": self.basis_activation`
with `"basis_activation": tf.keras.activations.serialize(self.basis_activation)`
in the `get_config`
method of `DenseKAN`. When the above error is fixed, I face a second problem: converting the model to TFLite. I haven't fixed this one, and trying to convert the model to TFLite raises an error:
Is there something I could do about this error?