ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Tune] `Tuner.restore` does not fully work with WandB Callback #38894

Open olipinski opened 11 months ago

olipinski commented 11 months ago

What happened + What you expected to happen

When using `Tuner.restore`, Tune updates the experiment's stored paths to match the current directory, which makes sense when moving checkpoints across machines.

However, it does not update the paths set in callbacks (nor do I see how it could do so automatically). One example is the WandB callback, whose `api_key_file` parameter still points to the original location after a restore. Could `Tuner.restore` allow overriding certain callback parameters?
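To make the requested override concrete, here is a minimal, hypothetical sketch in plain Python. It assumes restored callbacks are ordinary objects whose attributes can be reassigned after `Tuner.restore` and before training resumes. `FakeWandbCallback` and `override_callback_attrs` are made-up names for illustration, not part of Ray's API, and a real callback that reads its key file before the override is applied would not pick up the change.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type


@dataclass
class FakeWandbCallback:
    """Hypothetical stand-in for a logger callback restored with an experiment."""

    project: str
    api_key_file: Optional[str] = None


def override_callback_attrs(
    callbacks: List[Any], overrides: Dict[Type, Dict[str, Any]]
) -> List[Any]:
    """Hypothetical helper: reassign attributes on callbacks, matched by class.

    Only attributes the callback already has may be overridden, so a misspelled
    name raises AttributeError instead of silently adding a new attribute.
    """
    for cb in callbacks:
        for cls, attrs in overrides.items():
            if not isinstance(cb, cls):
                continue
            for name, value in attrs.items():
                if not hasattr(cb, name):
                    raise AttributeError(f"{type(cb).__name__} has no attribute {name!r}")
                setattr(cb, name, value)
    return callbacks


# Usage: point a restored callback at the key file on the new machine.
restored = [FakeWandbCallback(project="test", api_key_file="/old/machine/wandb_api.key")]
override_callback_attrs(restored, {FakeWandbCallback: {"api_key_file": "./wandb_api.key"}})
print(restored[0].api_key_file)  # ./wandb_api.key
```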

Versions / Dependencies

Ray 2.6.1, Python 3.10

Reproduction script

import os
import torch
from ray.air.integrations.wandb import WandbLoggerCallback
from torch import nn

from ray import air, tune

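# Directory holding the W&B API key file. The callback keeps the path as given,
# and Tuner.restore does not rewrite it.
path = "./"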

# Minimal trainable: a small MLP whose weights are checkpointed every 2 iterations.
class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))

tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=air.RunConfig(
        stop={"training_iteration": 20},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2),
        callbacks=[
            WandbLoggerCallback(
                api_key_file=os.path.join(path, "wandb_api.key"),
                project="test",
                group="test_group",
                log_config=True,
            )
        ],
    ),
)
tuner.fit()

Issue Severity

Low: It annoys or frustrates me.

woshiyyya commented 10 months ago

@xwjiang2010 Can you take a look?