cbrnr / sleepecg

Sleep stage detection using ECG
BSD 3-Clause "New" or "Revised" License

Cannot load classifiers with latest TensorFlow/Keras #214

Closed by cbrnr 2 months ago

cbrnr commented 2 months ago

Keras has been updated to v3, which unfortunately means that the built-in classifiers, which were created with v2, can no longer be loaded. The following example (from the documentation) demonstrates the issue:

from sleepecg import load_classifier

clf = load_classifier("ws-gru-mesa", "SleepECG")

This results in the following error:

Traceback (most recent call last):
  File "/home/clemens/Projects/sleepecg/mwe.py", line 3, in <module>
    clf = load_classifier("ws-gru-mesa", "SleepECG")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/clemens/Projects/sleepecg/sleepecg/classification.py", line 308, in load_classifier
    classifier = keras.models.load_model(f"{tmpdir}/classifier")
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/clemens/.pyenv/versions/sleepecg/lib/python3.12/site-packages/keras/src/saving/saving_api.py", line 193, in load_model
    raise ValueError(
ValueError: File format not supported: filepath=/tmp/tmpea7_jhpo/classifier. Keras 3 only supports V3 `.keras` files and legacy H5 format files (`.h5` extension). Note that the legacy SavedModel format is not supported by `load_model()` in Keras 3. In order to reload a TensorFlow SavedModel as an inference-only layer in Keras 3, use `keras.layers.TFSMLayer(/tmp/tmpea7_jhpo/classifier, call_endpoint='serving_default')` (note that your `call_endpoint` might have a different name).

I think we may have to re-create the classifiers and save them in the new v3 format. WDYT @hofaflo?
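
For reference, the error message distinguishes three on-disk formats: V3 `.keras` files, legacy `.h5` files, and legacy SavedModel directories. Below is a minimal sketch of how a loader could tell them apart and fall back to `keras.layers.TFSMLayer` for SavedModel directories, as the error message suggests. The helper names `detect_model_format` and `load_any` are hypothetical and not part of SleepECG. The `saved_model.pb` check is an assumption about how the bundled classifiers are laid out after extraction, and the endpoint name `serving_default` may differ for our models.

```python
from pathlib import Path


def detect_model_format(path):
    """Classify a saved model path into the formats named in the Keras 3 error."""
    path = Path(path)
    # Assumption: a legacy SavedModel is a directory containing saved_model.pb
    if path.is_dir() and (path / "saved_model.pb").exists():
        return "savedmodel"
    if path.suffix == ".keras":
        return "keras_v3"
    if path.suffix == ".h5":
        return "legacy_h5"
    return "unknown"


def load_any(path, call_endpoint="serving_default"):
    """Hypothetical loader: load_model where supported, TFSMLayer otherwise."""
    import keras  # imported lazily so format detection works without Keras

    fmt = detect_model_format(path)
    if fmt in {"keras_v3", "legacy_h5"}:
        return keras.models.load_model(path)
    if fmt == "savedmodel":
        # Inference-only layer, not a full keras.Model (no fit/compile/save)
        return keras.layers.TFSMLayer(str(path), call_endpoint=call_endpoint)
    raise ValueError(f"Unrecognized model format: {path}")
```

Note that `TFSMLayer` only gives an inference-only layer and requires the TensorFlow backend, so it would be a stopgap at best. Re-creating the classifiers and saving them as `.keras` files still looks like the cleaner fix.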

Also, and this is probably a separate issue, pyproject.toml lists tensorflow as a dependency, and SleepECG uses tensorflow.keras. Should we depend on Keras directly instead (with the TensorFlow backend)? With the new Keras v3, it should be possible to transparently switch between different backends when using the Keras API (https://keras.io/keras_3/).
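
To illustrate what the backend switch might look like: Keras 3 reads the `KERAS_BACKEND` environment variable when it is first imported. A rough sketch, assuming we would import `keras` directly instead of `tensorflow.keras`, follows. The helper `select_keras_backend` and the `SUPPORTED_BACKENDS` tuple are hypothetical names for illustration, not existing SleepECG or Keras API.

```python
import os

# Backends considered here; an illustrative list, not taken from SleepECG
SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch")


def select_keras_backend(name="tensorflow"):
    """Set KERAS_BACKEND; must be called before the first `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}, expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


# Usage sketch: choose the backend first, then import Keras.
# select_keras_backend("tensorflow")
# import keras  # instead of `from tensorflow import keras`
```

Whether the existing classifiers would actually run unchanged on another backend is an open question; this only shows where the choice would be made.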