ray-project / ray_lightning

Pytorch Lightning Distributed Accelerators using Ray

Teardown after trainer.fit() takes exceptionally long when using RayStrategy with large models #207

Closed MarkusSpanring closed 1 year ago

MarkusSpanring commented 1 year ago

I noticed that the teardown of the Ray workers takes exceptionally long when I use RayStrategy to train slightly larger models such as torchvision.models.resnet18, torch.nn.LSTM or torch.nn.Transformer.

I used the example from ray_ddp_example.py and replaced the model with a ResNet and the data with CIFAR10 to reproduce the issue. When I run vanilla PTL (set run="ptl"), training finishes as expected. However, with run="tune" or run="ptl_ray" the teardown after trainer.fit() takes over a minute. I also noticed that memory usage increases during teardown when using RayStrategy.
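
To roughly quantify this, I time trainer.fit() and compare the wall time with the epoch time shown in the progress bar; the gap is more or less the teardown. A sketch of how I look at the numbers (timed_fit is just a helper name used here, psutil comes from the environment listed below):

import time

import psutil


def timed_fit(trainer, model, dm):
    """Time trainer.fit() and report the driver's memory afterwards."""
    start = time.perf_counter()
    trainer.fit(model, dm)
    elapsed = time.perf_counter() - start
    rss_mb = psutil.Process().memory_info().rss / 2 ** 20
    print(f"trainer.fit() wall time: {elapsed:.1f}s, driver RSS: {rss_mb:.0f} MB")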

Is there something wrong with my setup, or is this the expected behavior? If you need any additional information, please let me know.

Thanks in advance!

import os

from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule
from pathlib import Path

import pytorch_lightning as pl
import ray
from ray import tune
from ray_lightning.tune import TuneReportCallback, get_tune_resources
from ray_lightning import RayStrategy

import torch
import torch.nn.functional as F
from torchvision.models import resnet18

import torchmetrics

class LightningCIFAR10Classifier(pl.LightningModule):
    def __init__(self, config, data_dir=None):
        super(LightningCIFAR10Classifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]

        # CIFAR10 images are (3, 32, 32) (channels, height, width)
        self.accuracy = torchmetrics.Accuracy()

        self.model = resnet18(num_classes=10)  # CIFAR10 has 10 classes

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        # resnet18 returns raw logits, so use cross_entropy rather than nll_loss
        loss = F.cross_entropy(logits, y.long())
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y.long())
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

def train_mnist(config,
                data_dir=None,
                num_epochs=1,
                num_workers=1,
                use_gpu=True,
                callbacks=None,
                strategy="ddp"):
    # Make sure data is downloaded on all nodes.
    def download_data():
        from filelock import FileLock
        with FileLock(os.path.join(data_dir, ".lock")):
            CIFAR10DataModule(data_dir=data_dir).prepare_data()

    model = LightningCIFAR10Classifier(config, data_dir)

    callbacks = callbacks or []

    if strategy == "ray":
        strategy = RayStrategy(
            num_workers=num_workers, use_gpu=use_gpu, init_hook=download_data
        )

    gpus = None
    if strategy == "ddp":
        gpus = [0]

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        callbacks=callbacks,
        gpus=gpus,
        progress_bar_refresh_rate=1,
        strategy=strategy
    )
    dm = CIFAR10DataModule(
        data_dir=data_dir, num_workers=1, batch_size=512)
    trainer.fit(model, dm)

def tune_mnist(data_dir,
               num_samples=1,
               num_epochs=1,
               num_workers=1,
               use_gpu=True):
    config = {
        "lr": tune.loguniform(1e-4, 1e-1)
    }

    # Add Tune callback.
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    callbacks = [TuneReportCallback(metrics, on="validation_end")]
    trainable = tune.with_parameters(
        train_mnist,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_workers=num_workers,
        use_gpu=use_gpu,
        callbacks=callbacks,
        strategy="ray")
    analysis = tune.run(
        trainable,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        resources_per_trial=get_tune_resources(
            num_workers=num_workers, use_gpu=use_gpu),
        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)

if __name__ == "__main__":

    # run = "tune"
    # run = "ptl_ray"
    run = "ptl"

    data_dir = Path("mnist_data")
    data_dir.mkdir(parents=True, exist_ok=True)
    config = {"lr": 0.01}

    if run == "tune":
        ray.init()
        tune_mnist(data_dir, num_samples=1, num_epochs=1, num_workers=1, use_gpu=True)

    if run == "ptl_ray":
        ray.init()
        train_mnist(
            config, data_dir, num_epochs=1, num_workers=1, use_gpu=True, strategy="ray"
        )

    if run == "ptl":
        train_mnist(config, data_dir, num_epochs=1, num_workers=1, use_gpu=True)

Below is the conda environment that I use:

name: torch
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - ca-certificates=2022.6.15
  - certifi=2022.6.15
  - cudatoolkit=11.6.0
  - intel-openmp=2022.0.1
  - ld_impl_linux-64=2.38
  - libffi=3.3
  - libgcc-ng=11.2.0
  - libgomp=11.2.0
  - libstdcxx-ng=11.2.0
  - mkl=2022.0.1
  - ncurses=6.3
  - openssl=1.1.1o
  - pip=22.1.2
  - python=3.9.12
  - python_abi=3.9=2_cp39
  - pytorch=1.12.1=py3.9_cuda11.6_cudnn8.3.2_0
  - pytorch-mutex=1.0=cuda
  - readline=8.1.2
  - setuptools=63.4.1
  - sqlite=3.39.2
  - tk=8.6.12
  - typing_extensions=4.3.0
  - tzdata=2022a
  - wheel=0.37.1
  - xz=5.2.5
  - zlib=1.2.12

  - pip:
      - pytorch-lightning==1.6.5
      - ray==2.0
      - ray[tune]==2.0
      - ray_lightning==0.3
      - numpy==1.22.4
      - flatten_json==0.1.13
      - overrides==6.2.0
      - gitpython==3.1.27
      - gputil==1.4.0
      - psutil==5.9.1
      - torchvision==0.13.1
      - torchtext==0.13.1
      - torchdata==0.4.1
      - lightning-bolts==0.5.0
      - torchmetrics==0.9.3
      - torch-fidelity==0.3.0
      - sacremoses==0.0.41
MarkusSpanring commented 1 year ago

After some digging, I found that moving the results to the CPU on rank zero leads to this bottleneck. Is there actually anything in _RayOutput here that still lives on the GPU such that it needs to be moved to the CPU?

FYI: When I replace

if trainer.strategy.local_rank == 0:
    return move_data_to_device(results, "cpu")

with

if trainer.strategy.local_rank == 0:
    return results

it seems to work fine. Is there a case where moving the results to the CPU is actually needed?
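
For reference, timing the transfer directly looks roughly like this (a sketch only, inside the same local_rank check as above; not a proposed fix):

import time

from pytorch_lightning.utilities.apply_func import move_data_to_device

if trainer.strategy.local_rank == 0:
    start = time.perf_counter()
    # This call appears to dominate the teardown time for larger models.
    results = move_data_to_device(results, "cpu")
    print(f"move_data_to_device took {time.perf_counter() - start:.1f}s")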