Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

What's the intended way of resuming training on a SLURM cluster? #16639

Open RoiEXLab opened 1 year ago

RoiEXLab commented 1 year ago

📚 Documentation

Hi, I'm not sure if this is the intended type of issue for this category, but I figured it wouldn't hurt to try:

I'm trying to use Lightning to train my model on a SLURM cluster because of its high memory requirements. For fairness, the cluster only allows jobs to run for 48 hours at a time, so I looked up in the documentation how to properly use checkpointing to resume where I left off. The documentation states this:

Resume training state

If you don’t just want to load weights, but instead restore the full training, do the following:

    model = LitModel()
    trainer = Trainer()

    # automatically restores model, epoch, step, LR schedulers, etc...
    trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

So I let it run only for as many epochs as fit before the wall-time limit kicks in. Lightning then automatically saves a checkpoint in the lightning_logs/version_21716264/checkpoints/ directory, where 21716264 is the ID of the SLURM job that ran the training. When I then manually requeue the job, I pass the checkpoint created by the previous run on the command line, and my script forwards it to trainer.fit. This seems to work at first glance, but I noticed something odd: when I do this, the following warning is issued:

Restoring states from the checkpoint path at lightning_logs/version_21716264/checkpoints/epoch=13-step=188874.ckpt
/path/to/venv/lib64/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:346: UserWarning: The dirpath has changed from '/path/to/lightning_logs/version_21716264/checkpoints' to '/path/to/lightning_logs/version_21719253/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
  warnings.warn(

This seems to suggest that, with my approach, training "reverts" to the best model instead of continuing from the last one, which would make training rather ineffective in later stages, when the validation loss does not necessarily decrease within a couple of epochs.
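Since each job writes to its own lightning_logs/version_<job id>/checkpoints/ directory, the manual step of finding the previous run's checkpoint can at least be scripted. A minimal sketch, assuming the default layout seen in the paths above; find_latest_checkpoint is a hypothetical helper, not a Lightning API:

```python
import glob
import os
from typing import Optional


def find_latest_checkpoint(log_root: str = "lightning_logs") -> Optional[str]:
    """Return the most recently modified .ckpt file under log_root, or None.

    Hypothetical helper: assumes the default layout
    <log_root>/version_<id>/checkpoints/<name>.ckpt
    """
    pattern = os.path.join(log_root, "version_*", "checkpoints", "*.ckpt")
    candidates = glob.glob(pattern)
    if not candidates:
        return None  # nothing saved yet, so training starts from scratch
    # The newest file wins, regardless of which version directory it is in
    return max(candidates, key=os.path.getmtime)


# Usage in the training script (ckpt_path=None means "don't resume"):
# trainer.fit(model, dm, ckpt_path=find_latest_checkpoint())
```

Picking by modification time avoids parsing epoch=13-step=188874-style file names. This only automates the lookup, though: the new job still gets a new version directory, so the dirpath warning would still appear.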

I tried searching the documentation for a solution to this problem, but it doesn't clearly explain the intended way to achieve this. The fact that it uses the SLURM job ID for the version folder suggests that Lightning is SLURM-aware by default, so I searched the Lightning documentation on SLURM. I stumbled upon #13773, where I learned that an "hpc" keyword exists for the checkpoint path, which is supposed to restore the latest "hpc" state. I tried using it, but I just get

ValueError: `.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found. Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`

So apparently there is some prerequisite for this? The "hpc" keyword is not mentioned on the cluster (advanced) page, and the trainer page only mentions it as part of the class API. It appears in none of the examples, yet the documentation says it "just works", which it clearly doesn't. A disclaimer at this point: I'm currently using version 1.8.1, but the changelog doesn't seem to mention a fix in this area, so I assume the problem still applies.
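For context on the ValueError: as far as I can tell from the Lightning source (this may differ between versions, so treat it as an assumption), ckpt_path="hpc" looks for files named hpc_ckpt_<N>.ckpt in the Trainer's default_root_dir, and those are only written when Lightning's SLURM integration catches the requeue signal before the job is killed. The hypothetical helper below mimics that lookup, so you can check whether any HPC checkpoint was ever written:

```python
import os
import re
from typing import Optional

# Assumed naming scheme, based on reading Lightning's source; may differ by version
_HPC_CKPT = re.compile(r"^hpc_ckpt_(\d+)\.ckpt$")


def latest_hpc_checkpoint(root_dir: str) -> Optional[str]:
    """Return the highest-numbered hpc_ckpt_<N>.ckpt in root_dir, or None.

    Hypothetical helper mirroring the lookup behind ckpt_path="hpc";
    root_dir stands in for the Trainer's default_root_dir.
    """
    if not os.path.isdir(root_dir):
        return None
    best = None
    for name in os.listdir(root_dir):
        match = _HPC_CKPT.match(name)
        # Compare numerically, so hpc_ckpt_10 beats hpc_ckpt_2
        if match and (best is None or int(match.group(1)) > best):
            best = int(match.group(1))
    if best is None:
        return None  # the situation in which .fit(ckpt_path="hpc") raises
    return os.path.join(root_dir, f"hpc_ckpt_{best}.ckpt")
```

If this returns None for the directory the job runs in, no HPC checkpoint was ever saved, which would be consistent with the signal never reaching the Python process.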

So at this point I'm confused. Do I need to store the checkpoint differently? Can't I use the automatic checkpoints for this? Or do I need to override the Lightning log directory so that everything ends up in the same directory and the warning doesn't appear? If so, how? The trainer and SLURM doc pages don't mention anything along these lines. Do I need to define a custom ModelCheckpoint callback, or a custom logger, just to get a fixed directory per model? Is there a bug in Lightning? I honestly have no idea what I'm supposed to do. The SLURM documentation also mentions auto-requeuing, but I suspect that, with my setup, bash doesn't forward the signal to Python properly. I really hope that's not a precondition for this to work, because again, the documentation doesn't mention anything like that.
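To test the suspicion that the signal never reaches Python, one could run a tiny probe in place of the training command in a short test job. A sketch using only the standard library (Unix only). SIGUSR1 is an assumption: as far as I know it is the default signal Lightning's SLURM integration listens for, and it would be requested from Slurm with sbatch's --signal option. install_probe and wait_for_signal are hypothetical names:

```python
import os
import signal
import time

received = []  # signal numbers seen by the probe


def _on_signal(signum, frame):
    # Only record and report the signal; do not act on it
    received.append(signum)
    print(f"pid {os.getpid()} received {signal.Signals(signum).name}", flush=True)


def install_probe(sig: signal.Signals = signal.SIGUSR1) -> None:
    """Install a handler that just logs `sig` (diagnostic only)."""
    signal.signal(sig, _on_signal)


def wait_for_signal(timeout_s: float, poll_s: float = 1.0) -> bool:
    """Poll until a probed signal has arrived or timeout_s elapses."""
    deadline = time.monotonic() + timeout_s
    while not received and time.monotonic() < deadline:
        time.sleep(poll_s)
    return bool(received)


# Usage in a short test job, launched the same way as the real training:
# install_probe()
# print("signal seen:", wait_for_signal(timeout_s=3600))
```

Run it on its own rather than inside the training script, so it doesn't interfere with the handler Lightning installs for the same signal.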

Anyway, I feel like there is a straightforward solution to all of my problems that the documentation isn't telling me. I'm using very limited Lightning functionality, so I naively assumed it would just work out of the box:

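    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPStrategy

    dm = MyDataModule()
    model = MyModel(learning_rate)
    trainer = Trainer(
        accelerator="gpu" if gpu_count > 0 else None,
        devices=gpu_count or None,
        max_epochs=max_epochs,
        # use DDP only when more than one GPU is available
        strategy=None if gpu_count <= 1 else DDPStrategy(find_unused_parameters=False, static_graph=True),
    )
    # checkpoint: path passed in on the command line
    trainer.fit(model, dm, ckpt_path=checkpoint)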
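One possible answer to the fixed-directory question, as a sketch rather than the documented approach: give ModelCheckpoint an explicit dirpath that does not depend on the job ID and set save_last=True, which writes last.ckpt into that directory; each new job then resumes from that file if it exists. Because dirpath no longer changes between jobs, the condition behind the warning (a changed dirpath) should not apply, though I haven't verified that the full callback state is then restored. RUN_DIR and resume_path are hypothetical names:

```python
import os
from typing import Optional


def resume_path(ckpt_dir: str, filename: str = "last.ckpt") -> Optional[str]:
    """Return ckpt_dir/filename if it exists, else None (start from scratch)."""
    path = os.path.join(ckpt_dir, filename)
    return path if os.path.isfile(path) else None


# Hypothetical wiring (not executed here; RUN_DIR is a placeholder):
# from pytorch_lightning.callbacks import ModelCheckpoint
# ckpt_dir = os.path.join(RUN_DIR, "checkpoints")
# checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_last=True)
# trainer = Trainer(..., callbacks=[checkpoint_cb])
# trainer.fit(model, dm, ckpt_path=resume_path(ckpt_dir))
```

The first job finds no last.ckpt and starts fresh; every later job picks up where the previous one stopped, without passing a path by hand.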

I'd be happy if you could point me in the right direction and add your suggestions to the official documentation, so that others won't have to struggle the way I did.

cc @borda @awaelchli

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

RoiEXLab commented 1 year ago

Bump

leoauri commented 10 months ago

I have essentially the same problem. It would be nice if someone who uses auto-resubmission in Slurm could chime in.

Does it actually work?

leoauri commented 9 months ago

I think my problem is actually a different one, or perhaps just a further source of confusion, but I'll note it here in case it's related:

The Slurm job ID is used as the version by the default logger. If the logger is defined explicitly, it numbers versions consecutively by default, so dirpath changes when the job is requeued, and none of the state listed in the warning above is reloaded.
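To make the dependency concrete, here is a sketch of how the checkpoint directory follows from the logger version. It assumes the layout visible in the warning above (<save_dir>/<name>/version_<N>/checkpoints) and that string versions are used verbatim as the directory name, which is how I understand TensorBoardLogger to behave; checkpoint_dirpath is a hypothetical helper, not a Lightning function:

```python
import os
from typing import Union


def checkpoint_dirpath(save_dir: str, name: str, version: Union[int, str]) -> str:
    """Hypothetical helper: where ModelCheckpoint would write by default.

    Assumes an int version becomes "version_<N>" and a str is used as-is.
    """
    version_dir = version if isinstance(version, str) else f"version_{version}"
    return os.path.join(save_dir, name, version_dir, "checkpoints")


# Same version before and after a restart -> same directory:
first = checkpoint_dirpath("/path/to", "lightning_logs", 21716264)
requeued = checkpoint_dirpath("/path/to", "lightning_logs", 21716264)
# Consecutive numbering hands the restarted run a new version -> new directory:
fresh = checkpoint_dirpath("/path/to", "lightning_logs", 1)
```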

I worked around this by adding something like:

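    import os

    import pytorch_lightning as pl

    # Use the Slurm job ID as the logger version so that a requeued job
    # writes to the same version directory (and the same checkpoint dirpath)
    array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
    if array_job_id is not None:
        array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
        # Keep this a string: int("123_4") would silently become 1234;
        # a string version is used as the directory name as-is
        job_id = f"{array_job_id}_{array_task_id}"
    else:
        # getenv (not environ[...]) so that job_id is None outside Slurm
        job_id = os.getenv("SLURM_JOB_ID")
        if job_id is not None:
            job_id = int(job_id)  # an int version gives "version_<id>"

    trainer = pl.Trainer(
        logger=pl.loggers.TensorBoardLogger(
            "runs",
            name=RUN_NAME,  # RUN_NAME: experiment name, defined elsewhere
            version=job_id,  # None falls back to consecutive numbering
        ),
        # ... other Trainer arguments
    )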

Some hint about this in the documentation would be useful.

heth27 commented 1 month ago

It would generally be useful to expand the documentation of expected behavior for restoring automatic checkpoints.