Lightning-AI / pytorch-lightning


Multi-node training expects num_nodes and devices, but that may be variable in a slurm cluster. #13804

Open TheShadow29 opened 2 years ago

TheShadow29 commented 2 years ago

πŸ› Bug

Currently, Trainer requires num_nodes and devices, but the number of devices may differ across nodes. For instance, SLURM may allocate 1 node with 6 GPUs and 2 other nodes with 1 GPU each, for a total of 8 GPUs across 3 nodes. Right now, this gives the following error:

..../python3.9/site-packages/pytorch_lightning/strategies/ddp.py", line 118, in root_device
    return self.parallel_devices[self.local_rank]
IndexError: list index out of range
srun: error: <node-name>: tasks 6-7: Exited with exit code 1

To Reproduce

Note: SL_NUM_NODES is set externally (see the SLURM script below).

# dummy_run.py
import os
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
import socket
import datetime
from pytorch_lightning.utilities.rank_zero import _get_rank

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def main():

    hostname = socket.gethostname()
    ngpus = torch.cuda.device_count()

    num_nodes = int(os.environ.get("SL_NUM_NODES", 1))
    wsize = int(os.environ.get("SLURM_NTASKS", ngpus * num_nodes))
    grank = _get_rank()
    print(
        f"Hostname={hostname}",
        f"nGPUs in host={torch.cuda.device_count()}",
        f"Start Time={datetime.datetime.now()}",
        f"num_nodes={num_nodes}",
        f"nCPUs = {torch.multiprocessing.cpu_count()}",
        f"wSize={wsize}",
        f"grank={grank}",
    )

    train_data = DataLoader(RandomDataset(32, 6400000), batch_size=2)

    model = BoringModel()

    trainer = Trainer(
        devices=ngpus,
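        # ngpus is the GPU count visible on *this* node, so ranks on different
        # nodes pass different values here when the allocation is heterogeneous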
        num_nodes=num_nodes,
        accelerator="gpu",
        strategy="ddp",
        limit_train_batches=100000,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        log_every_n_steps=1
    )
    trainer.fit(model, train_dataloaders=train_data)

    print("DONE")

if __name__ == "__main__":
    main()

And here is the SLURM script (placeholders like <partition-name> and <env-name> need to be filled in):

#!/bin/bash
#SBATCH --job-name=check_dummy_run
#SBATCH --nodes=3
#SBATCH --time=1:00:00
#SBATCH --partition=<partition-name>
#SBATCH --gpus=8
#SBATCH --ntasks=8
#SBATCH --cpus-per-task=10

source ~/miniconda/etc/profile.d/conda.sh
conda activate <env-name>

export SL_NUM_NODES=3
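# SL_NUM_NODES simply mirrors --nodes above; SLURM itself also exposes this
# as SLURM_NNODES / SLURM_JOB_NUM_NODES in the job environment.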
export PYTHONPATH=$(pwd)

srun python dummy_run.py

Expected behavior

Ideally, the world size should be provided by the cluster environment, and the Trainer should create subprocesses based only on the number of GPUs available on the current node.
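
For illustration, here is a minimal sketch of where that information already lives, assuming the SLURMEnvironment plugin shipped with recent Lightning versions and that the script runs inside an srun task (otherwise the SLURM environment variables are missing):

import torch
from pytorch_lightning.plugins.environments import SLURMEnvironment

# SLURM already describes the global layout through environment variables
# (SLURM_NTASKS, SLURM_PROCID, SLURM_LOCALID, ...), which the plugin reads.
env = SLURMEnvironment()
world_size = env.world_size()    # total number of tasks across all nodes
global_rank = env.global_rank()  # this task's rank in the whole job
local_rank = env.local_rank()    # this task's rank on its own node

# The devices usable by this task are just the GPUs visible on this node;
# nothing forces this number to be the same on every node.
local_devices = torch.cuda.device_count()
print(world_size, global_rank, local_rank, local_devices)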

Environment

cc @awaelchli @tchaton @rohitgr7 @justusschock @kaushikb11 @akihironitta

awaelchli commented 2 years ago

Hi

Yes, this is currently a known limitation. While it was a genuine limitation in the past, today it is somewhat artificial. I opened proposal #14078, which should pave the way to eventually removing it.

After #14078, you would simply set devices="auto" or devices=-1, and the actual number of devices could then differ per node.
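
For example, the Trainer call in the reproduction script could then reduce to something like the following (a sketch of the intent only; devices="auto" already exists as an API, and how it would resolve across heterogeneous nodes is exactly what #14078 has to define):

trainer = Trainer(
    accelerator="gpu",
    devices="auto",        # resolve to whatever GPUs are visible on each node
    num_nodes=num_nodes,   # num_nodes as in the reproduction script
    strategy="ddp",
)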

I'm removing the bug label because this can't really be delivered as a bug fix; it depends on the decision in #14078.