devanshkv / fetch

A set of deep learning models for FRB/RFI binary classification.
GNU General Public License v3.0

Migration from keras 2 to 3 #36

Open marcowenwolf opened 5 months ago

marcowenwolf commented 5 months ago

The updated requirements.txt pulls in Keras 3. The Keras 3 API differs from Keras 2 and requires code changes; see the guide "Migrating Keras 2 code to multi-backend Keras 3". Please migrate FETCH to Keras 3.

FYI: tensorflow[and-cuda]==2.15.0.post1 still loads Keras 2.

Note: the Keras model structure has also changed and no longer supports deep nesting. I am not sure whether this affects FETCH.

astrogewgaw commented 5 months ago

Yes, I faced this issue recently when using FETCH. If I run:

predict.py -n 4 -m a -c .

I get the following output:

Traceback (most recent call last):
  File "/data4/upanda/conda/bin/predict.py", line 4, in <module>
    __import__('pkg_resources').run_script('fetch==0.2.0', 'predict.py')
  File "/data4/upanda/conda/lib/python3.12/site-packages/pkg_resources/__init__.py", line 691, in run_script
    self.require(requires)[0].run_script(script_name, ns)
  File "/data4/upanda/conda/lib/python3.12/site-packages/pkg_resources/__init__.py", line 1530, in run_script
    exec(code, namespace, namespace)
  File "/data4/upanda/conda/lib/python3.12/site-packages/fetch-0.2.0-py3.12.egg/EGG-INFO/scripts/predict.py", line 79, in <module>
    model = get_model(args.model)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/data4/upanda/conda/lib/python3.12/site-packages/fetch-0.2.0-py3.12.egg/fetch/utils.py", line 92, in get_model
    model = model_from_json(j.read())
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data4/upanda/conda/lib/python3.12/site-packages/keras/src/models/model.py", line 575, in model_from_json
    return serialization_lib.deserialize_keras_object(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data4/upanda/conda/lib/python3.12/site-packages/keras/src/saving/serialization_lib.py", line 694, in deserialize_keras_object
    cls = _retrieve_class_or_fn(
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/data4/upanda/conda/lib/python3.12/site-packages/keras/src/saving/serialization_lib.py", line 812, in _retrieve_class_or_fn
    raise TypeError(
TypeError: Could not locate class 'Functional'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`.

and then it prints the full object config on my stdout.
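The failure happens while Keras 3 deserializes the architecture JSON, so it can help to look at that file directly before loading it. The following is a minimal sketch, not part of FETCH. It assumes the Keras 2 layout in which the top-level object carries `class_name`, `config` and `keras_version` keys and sub-models appear as entries in `config["layers"]`. The function names and the sample JSON are made up for illustration, and models wrapped inside other layers (for example `TimeDistributed`) are not followed.

```python
import json

# Class names that Keras 2 uses for model containers (assumption for this sketch).
CONTAINER_CLASSES = {"Functional", "Model", "Sequential"}


def nesting_depth(node):
    """Return how many model containers are nested inside each other (1 = flat model)."""
    if not isinstance(node, dict):
        return 0
    config = node.get("config")
    if isinstance(config, list):      # very old Sequential files store a bare list
        layers = config
    elif isinstance(config, dict):
        layers = config.get("layers", [])
    else:
        layers = []
    own = 1 if node.get("class_name") in CONTAINER_CLASSES else 0
    return own + max((nesting_depth(layer) for layer in layers), default=0)


def summarize_model_json(text):
    """Report the top-level class, the saving Keras version, and the nesting depth."""
    spec = json.loads(text)
    return {
        "class_name": spec.get("class_name"),
        "keras_version": spec.get("keras_version"),
        "max_nesting": nesting_depth(spec),
    }


# Made-up example: a functional model that contains another functional model.
SAMPLE = json.dumps({
    "class_name": "Functional",
    "keras_version": "2.4.0",
    "config": {"layers": [
        {"class_name": "Functional",
         "config": {"layers": [{"class_name": "Conv2D", "config": {"filters": 8}}]}},
        {"class_name": "Dense", "config": {"units": 2}},
    ]},
})

print(summarize_model_json(SAMPLE))
```

A `max_nesting` value above 1 would indicate that the file contains models inside models, which is the case the deep-nesting remark above refers to.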

marcowenwolf commented 4 months ago

A slightly better workaround is described on the Keras page:

If you were accessing keras via tf.keras, there are no immediate changes until TensorFlow 2.16. 
TensorFlow 2.16+ will use Keras 3 by default. In TensorFlow 2.16+, to keep using Keras 2, you can
first install tf_keras, and then export the environment variable TF_USE_LEGACY_KERAS=1. This will
direct TensorFlow 2.16+ to resolve tf.keras to the locally-installed tf_keras package. Note that this
may affect more than your own code, however: it will affect any package importing tf.keras in
your Python process. To make sure your changes only affect your own code, you should use the 
tf_keras package. 

So include the tf_keras package in the requirements.txt file and run 'export TF_USE_LEGACY_KERAS=1' prior to running your code. Hopefully this will work until Keras 3 is supported. Attachment: requirements.txt
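The same workaround can also be applied from inside a Python script, as long as the variable is set before TensorFlow is imported. The sketch below illustrates that ordering under the assumptions stated in the quoted text (TensorFlow 2.16+ with tf_keras installed). The helper names `enable_legacy_keras`, `tf_keras_installed` and `load_tf` are hypothetical and do not exist in FETCH or TensorFlow.

```python
import importlib.util
import os


def enable_legacy_keras(env=None):
    """Set TF_USE_LEGACY_KERAS=1 in the given mapping (default: os.environ)."""
    env = os.environ if env is None else env
    env["TF_USE_LEGACY_KERAS"] = "1"
    return env


def tf_keras_installed():
    """Return True if the tf_keras package can be found, without importing it."""
    return importlib.util.find_spec("tf_keras") is not None


def load_tf():
    """Hypothetical helper: import TensorFlow only after the flag has been set."""
    enable_legacy_keras()
    if not tf_keras_installed():
        raise RuntimeError("tf_keras is not installed; run `pip install tf_keras` first")
    import tensorflow as tf  # deliberately imported late, after the flag is set
    return tf
```

Calling `load_tf()` at the top of a script, before anything else imports TensorFlow, has the same effect as the `export` line. As the quoted text warns, it also affects every other package in the process that imports tf.keras.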

aweaver1fandm commented 2 months ago

To get the ball rolling on this issue, I started making code changes following the migration guide referenced in an earlier post. After a few false starts I made some progress, but the real problem is going to be the models themselves. I am not yet sure whether deep nesting is an issue (it probably is). What I am currently running into are other problems with the models: judging from the errors I am seeing, some of the keyword arguments in the model configuration have changed between Keras 2 and Keras 3. To become Keras 3 compatible, the "easiest" course of action would seem to be:

  1. Update the code as per the migration guide
  2. Once on keras 3, re-train the models so that they are keras 3 compatible.

I don't know how the models were trained in the first place (what parameters were used) so I can't do that.

In the meantime, we use Spack to install FETCH. I am modifying my Spack build recipe to force an upper limit on the versions of tensorflow and keras so that they stay in a range where Keras 2 is used.
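A pin like this can be double-checked at runtime. The sketch below is only an illustration: it relies on the statement quoted earlier in this thread that tf.keras stays on Keras 2 until TensorFlow 2.16, and the helper names are hypothetical. It reads the installed versions through the standard library and does not import TensorFlow.

```python
from importlib import metadata


def parse_release(version):
    """Return the leading numeric components, e.g. '2.15.0.post1' -> (2, 15, 0)."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def tf_defaults_to_keras2(tf_version):
    """True if this TensorFlow version predates 2.16, where tf.keras is still Keras 2."""
    return parse_release(tf_version)[:2] < (2, 16)


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name in ("tensorflow", "keras"):
        print(name, installed_version(name))
```

For example, `tf_defaults_to_keras2("2.15.0.post1")` returns `True`, while `tf_defaults_to_keras2("2.16.1")` returns `False`.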

devanshkv commented 1 month ago

@aweaver1fandm, deep nesting is a big issue but this is how Keras did it back when we were making these models. I think moving these models to PyTorch would be a better and more stable solution. lmk your thoughts.

aweaver1fandm commented 1 month ago

At this point I would strongly agree that moving to PyTorch is the way to go. As you said, deep nesting is an issue. In addition, I have been running into other TensorFlow problems when simply installing and running the Python packages. See, for example, these issues: https://github.com/tensorflow/tensorflow/issues/61468, https://github.com/tensorflow/tensorflow/issues/64809 and https://github.com/tensorflow/tensorflow/issues/63109

I have gotten to the point where I have adjusted your model and updated the code to Keras 3, but predict just hangs on the GPU, and the hang is related to these issues. With the newest TensorFlow, Keras may not run without a good deal of hacking to get all the dependencies installed. I have done some further digging, and Google is developing JAX, its newer library for machine learning. It makes me wonder how much further development will really happen with TensorFlow/Keras.

As for PyTorch, the Python code will have to be updated to use it instead of TensorFlow. I did find this guide on moving a model to PyTorch: https://www.geeksforgeeks.org/how-to-convert-a-tensorflow-model-to-pytorch/ Ideally, the models should be re-trained from scratch in PyTorch. I understand, however, that the original training data may no longer be available, in which case converting the existing models may be the way to move to PyTorch.