ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Checkpoint: AIR] Saved checkpoint folders do not include the correct training iteration number #29458

Open n30111 opened 2 years ago

n30111 commented 2 years ago

What happened + What you expected to happen

When the frequency parameter of the Keras Callback (from ray.air.callbacks.keras import Callback) is set, the checkpoint folder names do not include the correct training iteration number.

With frequency=1, the checkpoints follow the naming convention checkpoint_{(iteration-1):06d}. With frequency>1, however, the saved checkpoint folder names carry no information about the iteration number; the checkpoints are simply numbered consecutively. This is caused by the way checkpoint folders are created here: https://github.com/ray-project/ray/blob/master/python/ray/train/_internal/checkpoint.py#L228. That code simply increments self._latest_checkpoint_id without considering the frequency parameter.
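To make the difference concrete, here is a small pure-Python simulation of the two naming schemes. It is a simplification, not Ray's actual code: it assumes a checkpoint is saved at the end of every epoch whose 1-based number is divisible by frequency. The hypothetical `consecutive_names` mimics a counter like self._latest_checkpoint_id that is bumped once per saved checkpoint, while the hypothetical `epoch_names` names each folder after the 1-based epoch, matching the numbering the reporter expects.

```python
from typing import List


def saved_epochs(num_epochs: int, frequency: int) -> List[int]:
    """1-based epochs at which a checkpoint is saved (assumed: every `frequency` epochs)."""
    return [epoch for epoch in range(1, num_epochs + 1) if epoch % frequency == 0]


def consecutive_names(num_epochs: int, frequency: int) -> List[str]:
    """Reported behavior: the id is a counter incremented once per saved checkpoint."""
    names = []
    latest_checkpoint_id = 0
    for _ in saved_epochs(num_epochs, frequency):
        names.append(f"checkpoint_{latest_checkpoint_id:06d}")
        latest_checkpoint_id += 1
    return names


def epoch_names(num_epochs: int, frequency: int) -> List[str]:
    """Expected behavior: the id is the epoch at which the checkpoint was saved."""
    return [f"checkpoint_{epoch:06d}" for epoch in saved_epochs(num_epochs, frequency)]


print(consecutive_names(5, 2))  # ['checkpoint_000000', 'checkpoint_000001']
print(epoch_names(5, 2))        # ['checkpoint_000002', 'checkpoint_000004']
```

With frequency=1 the two schemes differ only by a constant offset of one (the (iteration-1) convention), which is why the problem only becomes visible with frequency>1.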

With frequency=1:

# output checkpoints: ['checkpoint_000002', 'checkpoint_000004', 'checkpoint_000000', 'checkpoint_000003', 'checkpoint_000001']

With frequency=2:

# output checkpoints: ['checkpoint_000000', 'checkpoint_000001']

Ideally, however, these folders should be numbered ['checkpoint_000002', 'checkpoint_000004'].

Versions / Dependencies

2.0.0

Reproduction script

The following script, a minor modification of the test at https://github.com/ray-project/ray/blob/releases/2.0.0/python/ray/air/tests/test_keras_callback.py, can be used to reproduce the bug.

from pathlib import Path

import numpy as np
import tensorflow as tf

import ray
from ray.air import session
from ray.air.callbacks.keras import Callback
from ray.air.constants import MODEL_KEY
from ray.train.constants import TRAIN_DATASET_KEY
from ray.air.config import RunConfig, ScalingConfig
from ray.train.tensorflow import (
    TensorflowTrainer,
    prepare_dataset_shard,
    TensorflowPredictor,
)

def get_dataset(a=5, b=10, size=1000):
    items = [i / size for i in range(size)]
    dataset = ray.data.from_items([{"x": x, "y": a * x + b} for x in items])
    return dataset

def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
    )
    return model

def train_func(config: dict, ckpt_freq=1):
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_model()
        multi_worker_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=config.get("lr", 1e-3)),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )

    dataset = session.get_dataset_shard("train")

    def to_tf_dataset(dataset, batch_size):
        def to_tensor_iterator():
            for batch in dataset.iter_tf_batches(
                batch_size=batch_size, dtypes=tf.float32
            ):
                yield batch["x"], batch["y"]

        output_signature = (
            # `(None,)` is a 1-D shape (a batch of scalars); `(None)` is just `None`.
            tf.TensorSpec(shape=(None,), dtype=tf.float32),
            tf.TensorSpec(shape=(None,), dtype=tf.float32),
        )
        tf_dataset = tf.data.Dataset.from_generator(
            to_tensor_iterator, output_signature=output_signature
        )
        return prepare_dataset_shard(tf_dataset)

    tf_dataset = to_tf_dataset(dataset=dataset, batch_size=32)
    multi_worker_model.fit(
        tf_dataset,
        callbacks=[Callback(frequency=ckpt_freq)],
        epochs=config["epochs"],
    )

def test_keras_callback_e2e(ckpt_freq=1):
    epochs = 5
    config = {
        "epochs": epochs,
    }
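    import tempfile
    # mkdtemp() creates a directory that persists. TemporaryDirectory().name points
    # to a directory that is deleted as soon as the unreferenced object is collected.
    tempdir = tempfile.mkdtemp()
    print(tempdir)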
    trainer = TensorflowTrainer(
        train_loop_per_worker=lambda config: train_func(config, ckpt_freq=ckpt_freq),
        train_loop_config=config,
        scaling_config=ScalingConfig(num_workers=2),
        datasets={TRAIN_DATASET_KEY: get_dataset()},
        run_config=RunConfig(local_dir=tempdir)
    )
    checkpoint = trainer.fit().checkpoint
    base_path = Path(checkpoint._local_path).parent
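    # Match every numbered checkpoint folder (not only ids 0-9) and sort,
    # since iterdir() returns entries in arbitrary order.
    ckpts = sorted(
        ckpt_dir.name
        for ckpt_dir in base_path.iterdir()
        if ckpt_dir.is_dir()
        and ckpt_dir.name.startswith("checkpoint_")
        and ckpt_dir.name[len("checkpoint_"):].isdigit()
    )
    return ckpts


if __name__ == "__main__":
    print("frequency=1:", test_keras_callback_e2e(ckpt_freq=1))
    print("frequency=2:", test_keras_callback_e2e(ckpt_freq=2))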

Issue Severity

High: It blocks me from completing my task.

amogkam commented 2 years ago

Thanks for reporting the issue, @n30111! This is definitely something we should fix.

I think we should switch to using the Tune Session API internally instead of tune.checkpoint_dir; Tune can then fill in the checkpoint step using the training_iteration from the corresponding metrics. cc @xwjiang2010 @Yard1
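One way to read this proposal, sketched in plain Python under stated assumptions: if the callback reports metrics on every epoch but attaches a checkpoint only every `frequency` epochs, the session's own training_iteration counter advances once per epoch, so the checkpoint folder can be named from it without the session knowing anything about `frequency`. `ToySession` and `run_callback` are hypothetical stand-ins, not the Ray Tune/AIR session API.

```python
from typing import Dict, List, Optional


class ToySession:
    """Hypothetical stand-in for a session: counts reports, names checkpoints by iteration."""

    def __init__(self) -> None:
        self.training_iteration = 0
        self.checkpoint_dirs: List[str] = []

    def report(self, metrics: Dict[str, float], checkpoint: Optional[dict] = None) -> None:
        # Every report advances the iteration, whether or not it carries a checkpoint.
        self.training_iteration += 1
        if checkpoint is not None:
            # The folder name comes from the iteration of the metrics it was reported with.
            self.checkpoint_dirs.append(f"checkpoint_{self.training_iteration:06d}")


def run_callback(session: ToySession, epochs: int, frequency: int) -> None:
    """Mimic a Keras-style callback: report each epoch, checkpoint every `frequency` epochs."""
    for epoch in range(1, epochs + 1):
        metrics = {"loss": 1.0 / epoch}  # placeholder metric values
        checkpoint = {"epoch": epoch} if epoch % frequency == 0 else None
        session.report(metrics, checkpoint=checkpoint)


session = ToySession()
run_callback(session, epochs=5, frequency=2)
print(session.checkpoint_dirs)  # ['checkpoint_000002', 'checkpoint_000004']
```

In this design `frequency` stays entirely inside the callback; the session only needs a per-report counter. Whether Ray's session would number from 0 or 1 (compare the checkpoint_{(iteration-1):06d} convention) is a detail this sketch does not settle.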

dumpmemory commented 2 years ago

The same issue exists for HuggingFaceTrainer: when saving by steps (e.g., every 1000 steps), the first checkpoint is named checkpoint_000000 rather than after step 1000.
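Until the folder names carry the step, the step can in principle be recovered from a consecutive id by arithmetic, as in the hypothetical helpers below. This assumes checkpoints are saved exactly every `save_steps` steps, the first at step `save_steps`, and that ids count saves in order with no gaps; it is an illustration, not a Ray utility.

```python
def parse_checkpoint_id(dir_name: str) -> int:
    """Extract the numeric id from a folder name such as 'checkpoint_000000'."""
    prefix = "checkpoint_"
    if not dir_name.startswith(prefix):
        raise ValueError(f"not a checkpoint folder: {dir_name!r}")
    return int(dir_name[len(prefix):])


def checkpoint_id_to_step(checkpoint_id: int, save_steps: int) -> int:
    """Map a 0-based consecutive checkpoint id to the step it was (assumed) saved at."""
    if checkpoint_id < 0 or save_steps <= 0:
        raise ValueError("checkpoint_id must be >= 0 and save_steps must be > 0")
    # id 0 -> step save_steps, id 1 -> step 2 * save_steps, ...
    return (checkpoint_id + 1) * save_steps


print(checkpoint_id_to_step(parse_checkpoint_id("checkpoint_000000"), save_steps=1000))  # 1000
```

If any of these assumptions does not hold, the mapping is silently wrong, which is one argument for recording the step explicitly.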

Yard1 commented 2 years ago

How is this impacting workloads, aside from the Keras callback not saving the epoch? As far as I understand, the most important thing is that we have an incremental counter for checkpoints. The actual epoch/iteration number should be saved inside the checkpoint itself (which is indeed the case with Huggingface, but not with the Keras callback).
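This argument can be illustrated as follows: if each checkpoint records its own epoch, the folder counter only has to increase, and the epoch is read from the checkpoint contents. In this hypothetical sketch, checkpoints are plain dicts keyed by folder name; in Ray 2.0 one might build such a payload with something like `Checkpoint.from_dict`, but the dict layout and the `"epoch"` key here are assumptions, not what the Keras callback stores.

```python
from typing import Dict, Tuple

# Hypothetical on-disk state: consecutive folder ids, each payload storing its epoch.
checkpoints: Dict[str, dict] = {
    "checkpoint_000000": {"epoch": 2, "model_weights": None},  # weights omitted
    "checkpoint_000001": {"epoch": 4, "model_weights": None},
}


def latest_by_epoch(ckpts: Dict[str, dict]) -> Tuple[str, int]:
    """Return (folder_name, epoch) of the checkpoint with the highest stored epoch."""
    name = max(ckpts, key=lambda n: ckpts[n]["epoch"])
    return name, ckpts[name]["epoch"]


print(latest_by_epoch(checkpoints))  # ('checkpoint_000001', 4)
```

The trade-off is that the epoch is no longer visible from the folder name alone; each checkpoint has to be opened to find it.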

dumpmemory commented 2 years ago


Replying to @Yard1: keeping the checkpoint number consistent with the Huggingface checkpoint number would be more convenient for managing checkpoints.

xwjiang2010 commented 2 years ago

@amogkam I'm not sure I follow. How would the Tune session know about application-specific details (frequency, etc.)?

dumpmemory commented 2 years ago

I haven't set checkpoint_frequency in CheckpointConfig

n30111 commented 1 year ago

@amogkam any update on this issue?

anyscalesam commented 5 months ago

@justinvyu does #36220 resolve this?