ray-project / ray_lightning

Pytorch Lightning Distributed Accelerators using Ray
Apache License 2.0

[Tune] PBT/PB2 doesn't work correctly with Ray Lightning #145

Open yinweisu opened 2 years ago

yinweisu commented 2 years ago

When using PBT/PB2, I received the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

This issue happens after a trial is paused and resumed. I was able to reproduce it with some modifications to the example provided by Ray Lightning:

"""Simple example using RayAccelerator and Ray Tune"""
import os
import tempfile

from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

import pytorch_lightning as pl
import ray
from ray import tune
from ray_lightning.tune import TuneReportCheckpointCallback, get_tune_resources
from ray_lightning import RayPlugin
from ray_lightning.tests.utils import LightningMNISTClassifier
from ray.tune.schedulers.pb2 import PB2

def train_mnist(config,
                checkpoint_dir=None,
                data_dir=None,
                num_epochs=10,
                num_workers=1,
                use_gpu=False,
                callbacks=None):
    # Make sure data is downloaded on all nodes.
    def download_data():
        from filelock import FileLock
        with FileLock(os.path.join(data_dir, ".lock")):
            MNISTDataModule(data_dir=data_dir).prepare_data()

    model = LightningMNISTClassifier(config, data_dir)

    callbacks = callbacks or []
    checkpoint_path = None
    if checkpoint_dir is not None:
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint')

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        callbacks=callbacks,
        progress_bar_refresh_rate=0,
        plugins=[
            RayPlugin(
                num_workers=num_workers,
                use_gpu=use_gpu,
                init_hook=download_data)
        ])
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])
    trainer.fit(model, dm, ckpt_path=checkpoint_path)

def tune_mnist(data_dir,
               num_samples=10,
               num_epochs=10,
               num_workers=1,
               use_gpu=False):
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }
    scheduler = PB2(hyperparam_bounds={"lr": [1e-4, 1e-1]})

    # Add Tune callback.
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    callbacks = [TuneReportCheckpointCallback(metrics, on="validation_end", filename="checkpoint")]
    trainable = tune.with_parameters(
        train_mnist,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_workers=num_workers,
        use_gpu=use_gpu,
        callbacks=callbacks)
    analysis = tune.run(
        trainable,
        scheduler=scheduler,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        resources_per_trial=get_tune_resources(
            num_workers=num_workers, use_gpu=use_gpu),
        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-workers",
        type=int,
        help="Number of training workers to use.",
        default=1)
    parser.add_argument(
        "--use-gpu", action="store_true", help="Use GPU for training.")
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10,
        help="Number of samples to tune.")
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=10,
        help="Number of epochs to train for.")
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    args, _ = parser.parse_known_args()

    num_epochs = 1 if args.smoke_test else args.num_epochs
    num_workers = 1 if args.smoke_test else args.num_workers
    use_gpu = False if args.smoke_test else args.use_gpu
    num_samples = 1 if args.smoke_test else args.num_samples

    if args.smoke_test:
        ray.init(num_cpus=2)
    else:
        ray.init(address=args.address)

    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    tune_mnist(data_dir, num_samples, num_epochs, num_workers, use_gpu)

The args I passed in: python3 test_ray_lightning.py --use-gpu --num-workers 2 --num-samples 4

Versions:

pytorch-lightning==1.5.10
torch==1.10.2
ray==1.12.0
ray-lightning==0.2.0
xwjiang2010 commented 2 years ago

Hi @yinweisu, thanks for reporting. I believe this is a valid issue; I was able to reproduce it on my setup as well. After a bit of digging, it seems this is a known issue with ptl 1.5: https://github.com/PyTorchLightning/pytorch-lightning/discussions/11435 https://github.com/PyTorchLightning/pytorch-lightning/issues/12327

The solution is basically to upgrade to ptl 1.6. @amogkam: another data point that we should do the upgrade sooner rather than later.
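
In the meantime, a possible interim workaround (a minimal, untested sketch; the callback name below is made up for illustration and is not part of ray_lightning) is to move the restored optimizer state back onto the model's device when training resumes, since the device mismatch comes from optimizer tensors being loaded onto CPU:

import torch
import pytorch_lightning as pl

class OptimizerStateToDeviceCallback(pl.Callback):
    """Hypothetical workaround: after resuming from a checkpoint, the optimizer
    state can end up on CPU while the model lives on cuda:0. Move it back."""

    def on_train_start(self, trainer, pl_module):
        device = pl_module.device
        for optimizer in trainer.optimizers:
            for state in optimizer.state.values():
                for key, value in state.items():
                    if torch.is_tensor(value):
                        state[key] = value.to(device)

It could be appended to the callbacks list that train_mnist passes to pl.Trainer, next to the TuneReportCheckpointCallback. Whether this interacts cleanly with the RayPlugin workers has not been verified, so treat it as a stopgap until the ptl 1.6 upgrade.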

yinweisu commented 2 years ago

Thanks! And yes, upgrading to ptl 1.6 soon would be awesome!

dynamicwebpaige commented 2 years ago

+1 for upgrading to PyTorch Lightning 1.6! Is there an estimate for when that work might occur?