ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Tune] TensorflowCheckpoint cannot save subclassed Keras model because it uses legacy H5 format #44804

Open giulatona opened 5 months ago

giulatona commented 5 months ago

What happened + What you expected to happen

While tuning the hyperparameters of a custom Keras model (a subclass of keras.Model) and requesting a checkpoint at each epoch end through ReportCheckpointCallback, the process fails. Specifically, ray.train.TensorflowCheckpoint tries to save the model using model.save(), which results in the following error:

(train_mnist pid=1990753) /home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via model.save(). This file format is considered legacy. We recommend using instead the native Keras format, e.g. model.save('my_model.keras').
(train_mnist pid=1990753)   saving_api.save_model(
2024-04-17 17:47:09,632 ERROR tune_controller.py:1374 -- Trial task failed for trial train_mnist_bc3d6_00000
Traceback (most recent call last):
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/_private/worker.py", line 2624, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(NotImplementedError): ray::ImplicitFunc.train() (pid=1990753, ip=150.145.127.41, actor_id=027e77bf46021d121a9c930f01000000, repr=train_mnist)
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 342, in train
    raise skipped from exception_cause(skipped)
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/air/_internal/util.py", line 88, in run
    self._ret = self._target(*self._args, **self._kwargs)
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/tune/trainable/function_trainable.py", line 115, in <lambda>
    training_func=lambda: self._trainable_func(self.config),
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/tune/trainable/function_trainable.py", line 332, in _trainable_func
    output = fn()
  File "/home/giuseppe/temp_repositories/tray_ray_tune/tune_mnist_subclass.py", line 56, in train_mnist
    model.fit(
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/air/integrations/keras.py", line 53, in on_epoch_end
    self._handle(logs, "epoch_end")
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/air/integrations/keras.py", line 163, in _handle
    checkpoint = TensorflowCheckpoint.from_model(self.model)
  File "/home/giuseppe/.local/share/virtualenvs/tray_ray_tune-7yFR_ThX/lib/python3.10/site-packages/ray/train/tensorflow/tensorflow_checkpoint.py", line 60, in from_model
    model.save(os.path.join(tempdir, filename))
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using save_weights.

This can be traced to the model.save() call in TensorflowCheckpoint.from_model, which somehow ends up using save_format == "h5".
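For context, Keras 2.x `Model.save` uses the `save_format` argument when it is given and otherwise infers the format from the file path (per the TF 2.x `tf.keras.Model.save` documentation): `.keras` selects the native format, `.h5`/`.hdf5` the legacy HDF5 format, and a path with no recognized extension defaults to SavedModel under TF2. The pure-Python sketch below mirrors that rule in simplified form; `infer_save_format` and `tf2_enabled` are hypothetical names, not Keras APIs, and the real Keras logic has extra cases this sketch ignores. It only illustrates which inputs lead to `"h5"`, the format that rejects subclassed models.

```python
def infer_save_format(filepath, save_format=None, tf2_enabled=True):
    """Hypothetical, simplified model of Keras 2.x save-format selection."""
    if save_format is not None:
        # An explicit save_format wins over the file extension.
        aliases = {"keras": "keras", "tf": "tf", "tensorflow": "tf",
                   "h5": "h5", "hdf5": "h5"}
        if save_format not in aliases:
            raise ValueError(f"Unknown save_format: {save_format!r}")
        return aliases[save_format]
    path = str(filepath)
    if path.endswith(".keras"):
        return "keras"  # native Keras format
    if path.endswith((".h5", ".hdf5")):
        return "h5"  # legacy HDF5: fails for subclassed models
    # No recognized extension: SavedModel under TF2, HDF5 under TF1.
    return "tf" if tf2_enabled else "h5"
```

In this simplified model, HDF5 is chosen only for an `.h5`/`.hdf5` suffix, an explicit `save_format="h5"`, or TF1 behavior.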

Versions / Dependencies

Ray[tune]: 2.10.0
Keras: 2.15
TensorFlow: 2.15
Python: 3.11
OS: Ubuntu 22.04

Reproduction script

import argparse
import os

from filelock import FileLock
import keras
from keras.datasets import mnist

import ray
from ray import train, tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.air.integrations.keras import ReportCheckpointCallback

@keras.saving.register_keras_serializable()
class SimpleModel(keras.Model):
    def __init__(self, hidden):
        super(SimpleModel, self).__init__()

        self.hidden = hidden
        self.flatten = keras.layers.Flatten(input_shape=(28, 28))
        self.dense1 = keras.layers.Dense(self.hidden, activation="relu")
        self.dropout = keras.layers.Dropout(0.2)
        self.dense2 = keras.layers.Dense(10, activation="softmax")

    def call(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout(x)
        return self.dense2(x)

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"hidden": self.hidden})
        return cfg

def train_mnist(config):
    # https://github.com/tensorflow/tensorflow/issues/32159
    # import tensorflow as tf

    batch_size = 128
    num_classes = 10
    epochs = 12

    with FileLock(os.path.expanduser("~/.data.lock")):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = SimpleModel(config["hidden"])

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=keras.optimizers.SGD(learning_rate=config["lr"], momentum=config["momentum"]),
        metrics=["accuracy"],
    )

    callbacks = [ReportCheckpointCallback(metrics={"mean_accuracy": "accuracy"})]

    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
        callbacks=callbacks,
    )

def tune_mnist():
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=400, grace_period=20
    )

    tuner = tune.Tuner(
        tune.with_resources(train_mnist, resources={"cpu": 12, "gpu": 1}),
        tune_config=tune.TuneConfig(
            metric="mean_accuracy",
            mode="max",
            scheduler=sched,
            num_samples=1,
        ),
        run_config=train.RunConfig(
            name="exp",
            stop={"mean_accuracy": 0.99},
        ),
        param_space={
            "threads": 2,
            "lr": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
            "hidden": tune.randint(32, 512),
        },
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)

tune_mnist()

Issue Severity

Medium: It is a significant difficulty but I can work around it.

woshiyyya commented 5 months ago

@giulatona instead of using ReportCheckpointCallback, can you save the model checkpoint yourself and use ray.train.report to upload the checkpoint?

Here's the user guide: https://docs.ray.io/en/master/train/user-guides/checkpoints.html

giulatona commented 5 months ago

Since I am using Keras model.fit, your suggestion would mean writing my own callback instead of using ReportCheckpointCallback. I do not think that would be useful. Also, the problem lies in the code of ReportCheckpointCallback, or more precisely in ray.train.TensorflowCheckpoint.

woshiyyya commented 4 months ago

ReportCheckpointCallback is a default solution provided by the Ray team, and it should cover common cases. If a case requires customization, users need to write their own report callback (e.g. one that saves the model in the Keras native format, in your case).

giulatona commented 4 months ago

I do not want to customize it; it just does not work, and I tried to explain why in the issue. I do not care how the models are saved.