Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Can't automatically resume a job: ckpt_path="hpc" throws ValueError from the start #20347

Open F-Barto opened 3 weeks ago

F-Barto commented 3 weeks ago

Summary

When attempting to resume a job from where it left off before reaching wall-time on a SLURM cluster using PyTorch Lightning, the ckpt_path="hpc" option causes an error if no HPC checkpoint exists yet. This prevents the initial training run from starting.

Expected Behavior

Setting ckpt_path="hpc" should start training from scratch when no HPC checkpoint exists yet (i.e. on the very first run) and resume from the latest HPC checkpoint after a requeue, instead of failing with:

`.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found. Please pass an exact checkpoint path to `.fit(ckpt_path=...)`

Current Behavior

Using ckpt_path=None allows the job to start but doesn't resume from the HPC checkpoint when one is created.

If I use trainer.fit(model, datamodule=dm, ckpt_path=None), the SIGUSR1 signal is correctly caught and the checkpoint hpc_ckpt_1.ckpt is correctly created. However, the checkpoint is not used on the requeued run, which is expected because we left ckpt_path=None.

requeing job 245038...
Requeued SLURM job: 245038
srun: Job step aborted: Waiting up to 62 seconds for job step to finish.
slurmstepd: error: *** JOB 245038 ON jzxh061 CANCELLED AT 2024-10-17T15:45:56 DUE TO JOB REQUEUE ***
slurmstepd: error: *** STEP 245038.0 ON jzxh061 CANCELLED AT 2024-10-17T15:45:56 DUE TO JOB REQUEUE ***

Using ckpt_path="hpc" throws an error if no HPC checkpoint is found, preventing the initial training run.

From what I understood, the logic of looking for and loading the HPC checkpoint should be handled by setting ckpt_path="hpc". However, as can be seen in https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py#L193C1-L199C46, if an HPC checkpoint is not found it throws an error and stops:

`.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found. Please pass an exact checkpoint path to `.fit(ckpt_path=...)`

The issue is that for the very first training run there is of course no HPC checkpoint, because no training has been started yet.
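
A possible workaround until this is addressed is a minimal guard, sketched below under the assumption that HPC checkpoints are written as hpc_ckpt_<n>.ckpt directly into the Trainer's default_root_dir (log_dir in the reproduction script below): only pass ckpt_path="hpc" when such a file already exists.

import os
import re

def hpc_ckpt_available(root_dir: str) -> bool:
    # True if at least one hpc_ckpt_<n>.ckpt sits directly in root_dir
    # (assumed layout, mirroring the filename seen in the logs above).
    if not os.path.isdir(root_dir):
        return False
    return any(re.fullmatch(r"hpc_ckpt_\d+\.ckpt", name) for name in os.listdir(root_dir))

# Resume from the HPC checkpoint when one exists, otherwise start fresh.
trainer.fit(model, datamodule=dm, ckpt_path="hpc" if hpc_ckpt_available(log_dir) else None)

On the very first submission this falls back to a fresh run; after a requeue the guard finds the HPC checkpoint and resumes from it.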

Relevant issues

#16639

What version are you seeing the problem on?

v2.4

How to reproduce the bug

dummy_model.py

import os
import torch
import lightning as L
from torch.utils.data import Dataset
from lightning.pytorch.callbacks import ModelCheckpoint
import argparse
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data.distributed import DistributedSampler
import pickle
import signal
from lightning.pytorch.plugins.environments import SLURMEnvironment
import time

class DummyDataset(Dataset):
    def __init__(self, size=100000):
        self.size = size
        self.data = torch.randn(size, 10)
        self.labels = torch.randint(0, 2, (size,))
        self.current_index = 0

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        print(f"Accessing index: {idx}, {self.data[idx]}")
        return self.data[idx], self.labels[idx]

class DPAwareDataLoader(StatefulDataLoader, Stateful):
    def __init__(self, dataset: Dataset, batch_size: int, sampler=None, **kwargs):
        super().__init__(dataset, batch_size=batch_size, sampler=sampler, **kwargs)
        self._rank_id = f"dp_rank_{sampler.rank if sampler else 0}"
        print(self._rank_id, " initialized")

    def state_dict(self):
        print(f"self._rank_id: ", f"{super().state_dict()}")
        return {self._rank_id: super().state_dict()}

    def load_state_dict(self, state_dict):
        if not state_dict:
            return
        if self._rank_id not in state_dict:
            print(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}")
            return
        print(f"DataLoader state loading for dp rank {self._dp_rank}")
        super().load_state_dict(state_dict[self._rank_id])

class DummyDataModule(L.LightningDataModule, Stateful):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = None
        self.dataloader = None

    def setup(self, stage=None):
        self.train_dataset = DummyDataset()

        # DistributedSampler automatically retrieves world_size and rank
        # from the current distributed group.
        #
        # ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
        #
        # In PyTorch Lightning:
        # - The distributed environment is initialized by the Trainer.
        # - This sets up the process group with the correct world_size and rank.
        # - DistributedSampler then uses these values automatically.
        #
        # By not specifying num_replicas and rank, we allow DistributedSampler
        # to adapt to the current distributed setup, making our code more flexible.
        # This works seamlessly with PyTorch Lightning's managed distributed training.
        #
        # Note: This automatic retrieval only works correctly if the distributed
        # environment has been initialized, which Lightning ensures before calling setup().
        self.sampler = DistributedSampler(
            self.train_dataset,
            shuffle=False  # Ensure deterministic data order across processes for testing purposes
        )

    def train_dataloader(self):
        if self.dataloader is None:
            self.dataloader = DPAwareDataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                sampler=self.sampler,
                num_workers=2
            )
        return self.dataloader

    def state_dict(self):
        return {
            "dataloader_state": self.dataloader.state_dict() if self.dataloader else None,
        }

    def load_state_dict(self, state_dict):
        if self.dataloader and state_dict["dataloader_state"]:
            self.dataloader.load_state_dict(state_dict["dataloader_state"])

class DummyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 2)
        self.loss = torch.nn.CrossEntropyLoss()
        self.example_count = 0
        self.custom_global_step = 0

    def training_step(self, batch, batch_idx):
        time.sleep(10)
        x, y = batch
        y_hat = self.layer(x)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss)

        self.example_count += len(x)
        self.custom_global_step = self.global_step

        if self.example_count % 100 == 0:
            print(f"GPU {self.global_rank}: Processed {self.example_count} examples, Global Step: {self.custom_global_step}")

        return loss

    def on_save_checkpoint(self, checkpoint):
        checkpoint['example_count'] = self.example_count
        checkpoint['custom_global_step'] = self.custom_global_step

    def on_load_checkpoint(self, checkpoint):
        self.example_count = checkpoint['example_count']
        self.custom_global_step = checkpoint['custom_global_step']

    def on_train_start(self):
        print(f"GPU {self.global_rank}: Starting/Resuming training. Example count: {self.example_count}, Global Step: {self.custom_global_step}")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

def main():
    parser = argparse.ArgumentParser(description="Train a dummy model with a unique run name.")
    parser.add_argument("run_name", type=str, help="Unique name for this training run")
    args = parser.parse_args()
    print("="*30, args.run_name, "="*30)

    log_dir = os.path.join("logs", args.run_name)
    os.makedirs(log_dir, exist_ok=True)

    model = DummyModel()
    dm = DummyDataModule(batch_size=4)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(log_dir, 'checkpoints'),
        filename='model-{epoch:02d}-{train_loss:.2f}',
        save_top_k=1,
        verbose=True,
        monitor='train_loss',
        mode='min',
        every_n_epochs=1,
        save_last=True
    )

    trainer = L.Trainer(
        max_epochs=100,
        devices=4,
        accelerator='gpu',
        strategy='ddp',
        callbacks=[checkpoint_callback],
        plugins=[SLURMEnvironment(auto_requeue=True, requeue_signal=signal.SIGUSR1)],
        default_root_dir=log_dir,
        use_distributed_sampler=False,
    )

    trainer.fit(model, datamodule=dm, ckpt_path="last")

if __name__ == '__main__':
    main()

dummy_slurm.sh

#!/bin/bash
#SBATCH --job-name=auto_requeue_test
#SBATCH -C h100
#SBATCH -A ycy@h100
#SBATCH --nodes=1
#SBATCH --qos=qos_gpu_h100-dev
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --time=00:03:00
#SBATCH --signal=SIGUSR1@30  # Send signal 30 seconds before time limit

# Load any necessary modules or activate your environment here
# For example:
module purge
module load arch/h100
module load pytorch-gpu/py3/2.4.0
export PYTHONUSERBASE=$WORK/python_envs/worldmodel

echo "Starting job at $(date)"

# Generate a unique run name using the current date and time
RUN_NAME="run_${SLURM_JOB_ID}"

# Run the Python script with the unique run name
srun python dummy_model.py "$RUN_NAME"

echo "Job ended or requeued at $(date)"

Environment

Current environment
- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.4.0
- Python version: 3.11.9
- How you installed Lightning (`conda`, `pip`, source): pip
arijit-hub commented 2 weeks ago

Hi @F-Barto, you do not need to specify ckpt_path="hpc". In the current implementation of Lightning (url) the trainer always searches for the "hpc" checkpoint internally first; only if it doesn't find one does it fall back to the ckpt_path you pass in trainer.fit(..., ckpt_path=ckpt_path). So you shouldn't specify ckpt_path="hpc" and instead just do something like this:

## Option 1:
## If you want to keep the option of manual resuming,
## add a flag for it (args.resume).
# This will auto-requeue, or resume from last.ckpt, or run from the start.
trainer.fit(..., ckpt_path="/path/to/saved/checkpoints/last.ckpt" if args.resume else None)

## Option 2:
## No manual resuming, just auto-requeue or run from the start.
trainer.fit(..., ckpt_path=None)

Typically, when your job (launched from the .sh file) is about to hit the wall-time, Lightning automatically creates a temporary checkpoint (hpc_ckpt_*.ckpt) in the default_root_dir set by the user on the Trainer. When the job restarts, Lightning automatically searches the default_root_dir folder and, if an hpc_ckpt_*.ckpt file is there, loads it and resumes training.
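
For reference, that lookup on restart can be approximated by the small sketch below (the helper name is mine, and the assumption that the checkpoints are named hpc_ckpt_<n>.ckpt and live directly in default_root_dir is an assumption for illustration, not Lightning's public API): the file with the highest <n> wins, and None means training starts from scratch.

import os
import re
from typing import Optional

def find_latest_hpc_ckpt(default_root_dir: str) -> Optional[str]:
    # Approximation of the internal search described above: return the
    # hpc_ckpt_<n>.ckpt with the highest <n>, or None if none exists yet.
    if not os.path.isdir(default_root_dir):
        return None
    versions = [
        int(m.group(1))
        for name in os.listdir(default_root_dir)
        if (m := re.fullmatch(r"hpc_ckpt_(\d+)\.ckpt", name))
    ]
    if not versions:
        return None
    return os.path.join(default_root_dir, f"hpc_ckpt_{max(versions)}.ckpt")

With ckpt_path=None, the path found this way (when one exists) is what gets loaded on the requeued run, which is why Option 2 above still resumes after a wall-time requeue.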

I hope this helps and was clear enough. Let me know if something is still confusing.