Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`configure_optimizers` in LightningCLI doesn't work in distributed mode #16489

Open stevenmanton opened 1 year ago

stevenmanton commented 1 year ago

Bug description

LightningCLI provides a way to override a model's configure_optimizers method. You can do this via a subclass or through the configuration (see the sketch below). This works well on CPU and with a single GPU, but it seems to fail with multiple GPUs. That is, the optimizer settings provided in the configuration aren't actually applied in distributed mode.
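
If I understand the two options correctly, a minimal sketch might look like this (the MyCLI name is made up; this assumes the documented LightningCLI API):

import torch
from pytorch_lightning.cli import LightningCLI

class MyCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Register an optimizer class so the CLI instantiates it and
        # patches configure_optimizers onto the model
        parser.add_optimizer_args(torch.optim.Adam)

# Or, without subclassing, purely through the configuration / command line:
#   python train.py fit --optimizer=Adam --optimizer.lr=0.001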

How to reproduce the bug

Because logging the optimizer is tricky in distributed mode, one easy way to test this is to remove the configure_optimizers method from the model. Since LightningCLI.configure_optimizers overrides this method anyway, removing it should be harmless. However, because the override fails in distributed mode, you'll get an error instead.

For example, with the following script:

# dummy.py
from pytorch_lightning.demos.boring_classes import BoringModel, BoringDataModule
from pytorch_lightning.cli import LightningCLI

# Delete the model's own configure_optimizers so the run can only succeed
# if the method patched in by LightningCLI is actually used
del BoringModel.configure_optimizers

if __name__ == "__main__":
    cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={
        "max_steps": 1,
        "accelerator": "gpu",
    })

You can run:

python src/product_dna/_antonstv/train/dummy.py fit --optimizer=Adam --trainer.devices=1

and the script will complete. But if you run

python src/product_dna/_antonstv/train/dummy.py fit --optimizer=Adam --trainer.devices=2

then you'll get an exception.

Error messages and logs

Traceback (most recent call last):
  File "src/product_dna/_antonstv/train/dummy.py", line 8, in <module>
    cli = LightningCLI(
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/cli.py", line 358, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/cli.py", line 670, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 113, in launch
    mp.start_processes(
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 139, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1029, in _run
    verify_loop_configurations(self)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py", line 41, in verify_loop_configurations
    __verify_train_val_loop_configuration(trainer, model)
  File "/home/antonstv/miniconda3/envs/pdna/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py", line 80, in __verify_train_val_loop_configuration
    raise MisconfigurationException(
lightning_fabric.utilities.exceptions.MisconfigurationException: No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.

Environment

More info

This error took me a really long time to track down. Even if there's not an easy fix, it would be great to throw a warning.

cc @borda @carmocca @mauvilsa @justusschock @awaelchli

awaelchli commented 1 year ago

@stevenmanton I don't understand. Why do you think having this line of code is appropriate or expected?

del BoringModel.configure_optimizers

awaelchli commented 1 year ago

I think this is just a multiprocessing issue/limitation. You should be fine if you move the del BoringModel.configure_optimizers under the if __name__ == "__main__" guard (again, I don't understand what the purpose is). Alternatively, you can set --trainer.strategy=ddp (the default is ddp_spawn).
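
A minimal sketch of that first suggestion, assuming nothing else about dummy.py changes:

# dummy.py (del moved under the main guard)
from pytorch_lightning.demos.boring_classes import BoringModel, BoringDataModule
from pytorch_lightning.cli import LightningCLI

if __name__ == "__main__":
    # The deletion now happens only in the main process; spawned workers
    # re-import the module and keep the original method
    del BoringModel.configure_optimizers

    cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={
        "max_steps": 1,
        "accelerator": "gpu",
    })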

stevenmanton commented 1 year ago

Hi @awaelchli, maybe the test case is confusing the general issue. What I've observed is that any optimizer or learning rate scheduler that's specified in a YAML file works correctly when running on a single process (e.g. CPU or single GPU), but does NOT work in the distributed mode (e.g. multiple GPUs).

For example, say you define a model with an AdamW optimizer:

import pytorch_lightning as pl
from torch.optim import AdamW

class MyModel(pl.LightningModule):
    ...
    def configure_optimizers(self):
        return AdamW(self.parameters())

You can override the AdamW optimizer with, say, a Lamb optimizer via the YAML config:

# config.yaml
optimizer:
  class_path: torch_optimizer.Lamb

However, if you run with this config on multiple GPUs, your model won't train with the Lamb optimizer but with the default AdamW optimizer.

For example:

# This will train a job with the Lamb optimizer:
python src/product_dna/_antonstv/train/dummy.py fit -c config.yaml --trainer.devices=1
# This will train a job with the AdamW optimizer:
python src/product_dna/_antonstv/train/dummy.py fit -c config.yaml --trainer.devices=2

What I was trying to show with the test case was that there's something about the logic with spawned processes that doesn't patch the MyModel.configure_optimizers method with the one specified in the YAML config. But I'm not sure I've set up that test case correctly. I was trying to add print statements to show what type of optimizer was being used, but because the processes are spawned, I don't think their stdout makes it back to the parent process.
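
One way to check which optimizer each worker actually uses, without relying on stdout from spawned processes, might be to dump it from a callback to a per-rank file. A minimal sketch (the OptimizerLogger name and file path are made up, and it assumes trainer.optimizers is populated by the time on_train_start runs):

import pytorch_lightning as pl

class OptimizerLogger(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        # Each rank writes its own file, so nothing is lost in spawned workers
        names = ", ".join(type(opt).__name__ for opt in trainer.optimizers)
        with open(f"optimizers_rank_{trainer.global_rank}.txt", "w") as f:
            f.write(names + "\n")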

I've observed the same issue with the learning rate scheduler.

Hopefully this clarifies what I've been observing.

awaelchli commented 1 year ago

Okay, this clarifies it. I did not even know that LightningCLI patches this method. I guess the "bug" makes sense.

LightningCLI patches the BoringModel class only in the main process, so the worker processes never see the patched class. In general, you can observe the same behavior with plain multiprocessing:

import torch.multiprocessing as mp

class BoringModel:
    def foo(self):
        return "foo"

def worker(i):
    # Spawned workers re-import the module in a fresh interpreter,
    # so they see the original, unpatched class
    print("in worker", BoringModel().foo())

if __name__ == "__main__":
    # The patching only happens in the main process
    BoringModel.foo = lambda _: "bar"
    print("in main", BoringModel().foo())

    mp.spawn(worker, nprocs=2)

The output is

in main bar
in worker foo
in worker foo

I'm not sure what we can do about this in the LightningCLI. At the time the processes get spawned, the model has already been instantiated.

I suggest that the LightningCLI error out when a user requests to patch the method and the launcher is a multiprocessing launcher with the spawn start method, with the suggestion that the user should switch to strategy=ddp (non-spawn).
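
A rough sketch of that check (the helper name is hypothetical and the launcher class is an internal detail, so this is an illustration rather than the actual fix):

from lightning_fabric.utilities.exceptions import MisconfigurationException
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher

def _validate_cli_optimizer_patching(trainer, cli_provides_optimizer: bool) -> None:
    # A patched configure_optimizers never reaches workers that are started
    # via spawn, because they re-import the original class
    launcher = getattr(trainer.strategy, "launcher", None)
    if cli_provides_optimizer and isinstance(launcher, _MultiProcessingLauncher):
        raise MisconfigurationException(
            "The optimizer/lr_scheduler given to LightningCLI cannot be applied with a "
            "spawn-based strategy. Switch to a non-spawn strategy such as --trainer.strategy=ddp."
        )

(A fork start method would actually inherit the patch, so a real check would probably also need to look at the launcher's start method.)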

stevenmanton commented 1 year ago

I think this is where the patching happens.

I agree that an error message at the very least would be helpful. It took me a very long time to figure out what was going on, especially because getting logs out of spawned processes is not easy.

Thanks for looking into this! Also, thanks for helping to maintain an awesome project!

stevenmanton commented 1 year ago

Update/summary (using dummy.py defined above):

# This works:
python src/product_dna/_antonstv/train/dummy.py fit --optimizer=Adam --trainer.devices=1
# This fails:
python src/product_dna/_antonstv/train/dummy.py fit --optimizer=Adam --trainer.devices=2
# This works:
python src/product_dna/_antonstv/train/dummy.py fit --optimizer=Adam --trainer.devices=2 --trainer.strategy=ddp

awaelchli commented 1 year ago

@carmocca What do you think? I don't see another way. My suggestion is:

I suggest that the LightningCLI error out when a user requests to patch the method and the launcher is a multiprocessing launcher with the spawn start method, with the suggestion that the user should switch to strategy=ddp (non-spawn).

carmocca commented 1 year ago

An error message would be helpful. However, I don't know whether we can reliably check if spawn will be used, particularly when "auto" is used.

We would need to swap these two lines to at least have a trainer reference: https://github.com/Lightning-AI/lightning/blob/fd61ed065a19b89f05b1d751ac1b9106cf68145c/src/lightning/pytorch/cli.py#L504-L505