Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

incorrect global_step with multiple optimizers and automatic_optimization=False #17958

Open jkyl opened 1 year ago

jkyl commented 1 year ago

Bug description

Hello,

I encountered a bug when training with automatic_optimization = False and two optimizers.

In summary: the global_step attribute of the trainer and the lightning module is tracking the total number of calls to optimizer.step() (in my case, two per training_step), rather than the total number of iterations of the dataloader.

This conflicts with the notion of step in arguments like log_every_n_steps and val_check_interval in the trainer. Case in point, if we call

self.log("global_step", self.global_step)

inside training_step, with CSVLogger, log_every_n_steps=10, and two optimizer.step()s per training_step, the CSV logs show:

global_step,epoch,step
20.0,0,9
40.0,0,19
60.0,0,29
80.0,0,39
100.0,0,49

Note how global_step conflicts with step, and in fact is twice the expected value, since we have two optimizers.

I have attached a complete code example that replicates the issue.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import pytorch_lightning as pl
import torch

from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import TensorDataset, IterableDataset, DataLoader

SEMVER = tuple(int(x) for x in pl.__version__.split("."))
assert SEMVER >= (2, 0, 3)

class LinearRegression(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.ones(()))
        self.beta = torch.nn.Parameter(torch.zeros(()))
        self.automatic_optimization = False

    def forward(self, x):
        return self.gamma * x + self.beta

    def configure_optimizers(self):
        gamma_opt = torch.optim.SGD([self.gamma], lr=1e-2)
        beta_opt = torch.optim.SGD([self.beta], lr=1e-2)
        return gamma_opt, beta_opt

    def training_step(self, batch, batch_idx):

        # Two optimizers.
        gamma_opt, beta_opt = self.optimizers()

        # Forward pass with loss.
        inputs, targets = batch
        predictions = self(inputs)
        loss = torch.nn.functional.mse_loss(predictions, targets)

        # Backprop through entire graph but only update gamma.
        gamma_opt.zero_grad()
        self.manual_backward(loss, retain_graph=True)
        gamma_opt.step()

        # Backprop through partial graph and only update beta.
        beta_opt.zero_grad()
        self.manual_backward(loss, inputs=[self.beta])
        beta_opt.step()

        # Log the global step.
        self.log("global_step_train", self.global_step)

    def validation_step(self, batch, batch_idx):

        # Forward pass with loss.
        inputs, targets = batch
        predictions = self(inputs)
        loss = torch.nn.functional.mse_loss(predictions, targets)

        # Log the global step.
        self.log("global_step_val", self.global_step)

class IterableTensorDataset(IterableDataset):

    def __init__(self, inputs, targets):
        self.inputs, self.targets = inputs, targets

    def __iter__(self):
        while True:
            i = torch.randint(self.inputs.shape[0], size=())
            yield self.inputs[i], self.targets[i]

def load_dataset(gamma=2.0, beta=-1.0, sigma=0.2):
    inputs = torch.linspace(-1, 1, 201)
    targets = gamma * inputs + beta
    targets += sigma * torch.randn_like(targets)
    indices = torch.randperm(inputs.shape[0])
    pivot = inputs.shape[0] // 2
    train_inds, val_inds = indices[:pivot], indices[pivot:]
    return (
        (inputs[train_inds], targets[train_inds]),
        (inputs[val_inds], targets[val_inds]))

def main():
    train_data, val_data = load_dataset()
    train_set = IterableTensorDataset(*train_data)
    val_set = TensorDataset(*val_data)
    train_loader = DataLoader(train_set, batch_size=4)
    test_loader = DataLoader(val_set, batch_size=1)
    trainer = pl.Trainer(
        log_every_n_steps=10,
        val_check_interval=100,
        logger=CSVLogger("./logs"),
        enable_progress_bar=False,
        max_steps=1000,
    )
    model = LinearRegression()
    trainer.fit(model, train_loader, test_loader)
    return (
        model.gamma.data.detach().cpu().item(),
        model.beta.data.detach().cpu().item())

if __name__ == "__main__":
    print(main())

Error messages and logs

global_step_train,epoch,step,global_step_val
20.0,0,9,
40.0,0,19,
60.0,0,29,
80.0,0,39,
100.0,0,49,
120.0,0,59,
140.0,0,69,
160.0,0,79,
180.0,0,89,
200.0,0,99,
,0,99,200.0
220.0,0,109,
240.0,0,119,
260.0,0,129,
280.0,0,139,
300.0,0,149,
320.0,0,159,
340.0,0,169,
360.0,0,179,
380.0,0,189,
400.0,0,199,
,0,199,400.0
420.0,0,209,
440.0,0,219,
460.0,0,229,
480.0,0,239,
500.0,0,249,
520.0,0,259,
540.0,0,269,
560.0,0,279,
580.0,0,289,
600.0,0,299,
,0,299,600.0
620.0,0,309,
640.0,0,319,
660.0,0,329,
680.0,0,339,
700.0,0,349,
720.0,0,359,
740.0,0,369,
760.0,0,379,
780.0,0,389,
800.0,0,399,
,0,399,800.0
820.0,0,409,
840.0,0,419,
860.0,0,429,
880.0,0,439,
900.0,0,449,
920.0,0,459,
940.0,0,469,
960.0,0,479,
980.0,0,489,
1000.0,0,499,
,0,499,1000.0

Environment

Current environment

* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning-utilities: 0.9.0
  - pytorch-lightning: 2.0.4
  - torch: 2.0.1
  - torchmetrics: 0.11.4
* Packages:
  - aiohttp: 3.8.4
  - aiosignal: 1.3.1
  - async-timeout: 4.0.2
  - attrs: 23.1.0
  - certifi: 2023.5.7
  - charset-normalizer: 3.1.0
  - filelock: 3.12.2
  - frozenlist: 1.3.3
  - fsspec: 2023.6.0
  - idna: 3.4
  - jinja2: 3.1.2
  - lightning-utilities: 0.9.0
  - markupsafe: 2.1.3
  - mpmath: 1.3.0
  - multidict: 6.0.4
  - networkx: 3.1
  - numpy: 1.25.0
  - packaging: 23.1
  - pip: 23.0.1
  - pytorch-lightning: 2.0.4
  - pyyaml: 6.0
  - requests: 2.31.0
  - setuptools: 67.6.0
  - sympy: 1.12
  - torch: 2.0.1
  - torchmetrics: 0.11.4
  - tqdm: 4.65.0
  - typing-extensions: 4.7.0
  - urllib3: 2.0.3
  - wheel: 0.38.4
  - yarl: 1.9.2
* System:
  - OS: Darwin
  - architecture: 64bit
  - processor: i386
  - python: 3.10.11
  - release: 20.6.0
  - version: Darwin Kernel Version 20.6.0: Thu Mar 9 20:39:26 PST 2023; root:xnu-7195.141.49.700.6~1/RELEASE_X86_64

More info

If this is the intended behavior, it should be reconciled with the trainer's notion of step. Arguments like log_every_n_steps and val_check_interval use a different definition of step.

jkyl commented 1 year ago

I corrected a mistake in the replication script (self.manual_backward versus loss.backward). The output is the same in both cases.

jkyl commented 1 year ago

It should also be noted that the max_steps=1000 argument to the trainer depends on global_step, which you can tell by the fact that the script terminates after 500 calls to training_step, even though max_steps was set to 1000. This is contrary to the definition of step used by log_every_n_steps and val_check_interval.

jkyl commented 1 year ago

The source of this behavior starts with the fact that trainer.global_step refers to the global_step property of training_epoch_loop.

In turn, that property derives its result from the optim_step_progress attribute of the _ManualOptimization loop object, whose total.completed attribute is incremented in _ManualOptimization._on_after_step.

Ultimately, _ManualOptimization._on_after_step is called via all of the LightningOptimizers created by the lightning module; the method is injected into every optimizer in _ManualOptimization.on_run_start.

One possible fix would be to inject only one of the optimizers with the total.completed incrementing behavior, rather than all.
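
A minimal sketch of that idea (not an actual Lightning patch), assuming the change is made in _ManualOptimization.on_run_start:

# Sketch only -- not Lightning's actual implementation.
def on_run_start(self) -> None:
    optimizers = self.trainer.strategy._lightning_optimizers
    for i, lightning_optimizer in enumerate(optimizers):
        if i == 0:
            # Only the first optimizer increments optim_step_progress.total.completed.
            lightning_optimizer._on_before_step = self._on_before_step
            lightning_optimizer._on_after_step = self._on_after_step
        else:
            # The others only start/stop the profiler, so global_step tracks iterations.
            lightning_optimizer._on_before_step = lambda: self.trainer.profiler.start("optimizer_step")
            lightning_optimizer._on_after_step = lambda: self.trainer.profiler.stop("optimizer_step")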

jkyl commented 1 year ago

Why this matters:

zhchz commented 1 year ago

+1 met the same thing

jkyl commented 1 year ago

Thought I'd provide a little more detail on my use case since other people have encountered this.

I'm training a GAN with multiple discriminator steps per generator step. My training step looks like this:

def training_step(self, batch, batch_idx):
    if batch_idx % self.n_critic == 0:
        self.update_generator_and_discriminator(batch)
    else:
        self.update_discriminator_only(batch)

This is more efficient than updating only one of the networks each iteration, because it allows one to re-use the generator outputs for the discriminator update. But it also means that update_generator_and_discriminator makes two calls to optimizer.step().
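
For concreteness, a hypothetical sketch of what these two methods might look like (illustrative only, not the actual training code; it assumes batches of (images, labels), a generator module self.generator, a latent size self.latent_dim, and loss helpers self.discriminator_loss / self.generator_loss):

# Hypothetical sketch. It only illustrates why update_generator_and_discriminator
# calls optimizer.step() twice in one training_step, while
# update_discriminator_only calls it once.
def update_generator_and_discriminator(self, batch):
    # Assumes configure_optimizers returns (generator_opt, discriminator_opt).
    g_opt, d_opt = self.optimizers()
    real, _ = batch
    fake = self.generator(torch.randn(real.size(0), self.latent_dim, device=real.device))

    # Discriminator update: first optimizer.step() of this iteration.
    d_opt.zero_grad()
    self.manual_backward(self.discriminator_loss(real, fake.detach()))
    d_opt.step()

    # Generator update re-using the same fake batch: second optimizer.step().
    g_opt.zero_grad()
    self.manual_backward(self.generator_loss(fake))
    g_opt.step()

def update_discriminator_only(self, batch):
    g_opt, d_opt = self.optimizers()
    real, _ = batch
    with torch.no_grad():
        fake = self.generator(torch.randn(real.size(0), self.latent_dim, device=real.device))

    # Only one optimizer.step() on these iterations.
    d_opt.zero_grad()
    self.manual_backward(self.discriminator_loss(real, fake))
    d_opt.step()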

As a workaround to this bug, I subclassed the trainer, like this:

class MyTrainer(pl.Trainer):

    def __init__(self, *, n_critic: int, **kwargs):
        super().__init__(**kwargs)
        self.n_critic = n_critic

    @property
    def global_step(self) -> int:
        return convert_global_step_to_current_iter(super().global_step, self.n_critic)

And I also implemented the following method:

def convert_global_step_to_current_iter(step: int, nc: int) -> int:
    return int(step * nc / (nc + 1))

This lets my callbacks run at the correct frequency, but it is not a general solution: it only applies to the case where global_step is incremented by 2 on every n_critic-th step and by 1 on every other step.
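
As a quick sanity check of the conversion (assuming n_critic = 4): over 4 dataloader iterations, one iteration calls optimizer.step() twice and the other three call it once, so global_step advances by 5 while the iteration count advances by 4.

# Illustrative check of convert_global_step_to_current_iter with n_critic = 4.
assert convert_global_step_to_current_iter(step=5, nc=4) == 4       # int(5 * 4 / 5)
assert convert_global_step_to_current_iter(step=1000, nc=4) == 800  # int(1000 * 4 / 5)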

leng-yue commented 1 year ago

This is a common issue for GANs and we should take a look.

jkyl commented 12 months ago

Thanks for the response! I'm happy to help with a PR, if you or anyone else has guidance for a way forward.

jkyl commented 11 months ago

hello, any triage or advice for this?

haihua commented 11 months ago

Hello, I have just tried lightning (2.0.6) and observed that my global_step is also out of sync with the actual number of steps. This is also reflected on TensorBoard and, in my case, leaves the learning rate unchanged as training progresses:

trainer = pl.Trainer(devices=2, accelerator='gpu', strategy='ddp', max_epochs=EPOCHS, logger=True, log_every_n_steps=50, check_val_every_n_epoch=1, callbacks=checkpoint_callback, accumulate_grad_batches=16,)

The trainer settings here have overwritten my NeMo config, which follows:

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: 1000
  max_steps: 200000 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: bf16 # 16, 32, or bf16
  log_every_n_steps: 100 # Interval of logging.
  enable_progress_bar: True
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  benchmark: false # needs to be false for models with variable-length speech input as it slows down training

So far, my training progress is:

Epoch 254: 23%|██▎ | 201/883 [02:12<07:29, 1.52it/s, v_num=44, loss_step=49.80, loss_epoch=51.10]

From this, I think I have already run 253*883 steps. However, this is what my TensorBoard is displaying:

[TensorBoard screenshot]

It tells me only about 14k global steps have been run, which is obviously wrong.

This is annoying, since Lightning changes the learning rate according to global_step, and that value is now mis-calculated. Besides, stopping does not work normally: I set max_steps to 200000, the actual number of steps run is already over 223399, and training has not stopped as expected.

[TensorBoard screenshot]

yw0nam commented 10 months ago

When I set self.automatic_optimization = False, I got the same error. It is caused by optimizer.step() increasing self.global_step by 1 each time it is called.

My observations are as follows.

When I use 2 optimizers, I get a global step 2 times larger than the actual step. When I use 3 optimizers, I get a global step 3 times larger than the actual step.

So, in this case, we need to figure out how the global_step increase on optimizer.step() should be handled for training to behave properly.

leng-yue commented 10 months ago

Maybe consider changing the definition of global step to be the number of times training_step is called. But this would be a breaking change... Adding a flag to opt in would be better.

Fitree commented 9 months ago

I've encountered the same problem and solved it as below. However, I'm not sure this method doesn't introduce another problem. If someone finds a possible edge case in my logic, please comment below.

[Background]

class _ManualOptimization(_Loop):
    """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
    entirely in the :meth:`~lightning.pytorch.core.module.LightningModule.training_step` and therefore the user is
    responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
    :meth:`~lightning.pytorch.core.module.LightningModule.training_step`) and passing through the output(s).

    """

    output_result_cls = ManualResult

    def __init__(self, trainer: "pl.Trainer") -> None:
        super().__init__(trainer)
        # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than
        # `_OptimizationProgress`
        self.optim_step_progress = _Progress.from_defaults(_ReadyCompletedTracker)

        self._output: _OUTPUTS_TYPE = {}

    def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
        self.on_run_start()
        with suppress(StopIteration):  # no loop to break at this level
            self.advance(kwargs)
        self._restarting = False
        return self.on_run_end()

    def on_run_start(self) -> None:
        # inject logic around the optimizer step
        for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
            lightning_optimizer._on_before_step = self._on_before_step
            lightning_optimizer._on_after_step = self._on_after_step

    def advance(self, kwargs: OrderedDict) -> None:
        """Performs the training step for manual optimization.

        Args:
            kwargs: The kwargs passed down to the hooks.

        """
        trainer = self.trainer

        # manually capture logged metrics
        training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
        del kwargs  # release the batch from memory
        self.trainer.strategy.post_training_step()
        result = self.output_result_cls.from_training_step_output(training_step_output)

        self._output = result.asdict()

    def on_run_end(self) -> _OUTPUTS_TYPE:
        """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
        output, self._output = self._output, {}  # free memory
        # reset logic around the optimizer step
        for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
            lightning_optimizer._on_before_step = do_nothing_closure
            lightning_optimizer._on_after_step = do_nothing_closure
        return output

    def _on_before_step(self) -> None:
        self.optim_step_progress.increment_ready()
        self.trainer.profiler.start("optimizer_step")

    def _on_after_step(self) -> None:
        self.trainer.profiler.stop("optimizer_step")
        self.optim_step_progress.increment_completed()

[Solution]

...
    def training_step(self, batch, batch_idx):
        gamma_opt, beta_opt = self.optimizers()
        beta_opt._on_before_step = lambda : self.trainer.profiler.start("optimizer_step")
        beta_opt._on_after_step = lambda : self.trainer.profiler.stop("optimizer_step")
        ...
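
Applied to the reproduction script from the issue description, the workaround would look roughly like this (an untested sketch; only the two hook overrides at the top are new):

# Untested adaptation of the workaround to the LinearRegression example above.
# Only gamma_opt's step() increments global_step; beta_opt's hooks are replaced
# with plain profiler calls, so global_step stays in sync with the dataloader.
def training_step(self, batch, batch_idx):
    gamma_opt, beta_opt = self.optimizers()

    # Re-override on every step: _ManualOptimization.on_run_start reassigns
    # these hooks before each call to training_step.
    beta_opt._on_before_step = lambda: self.trainer.profiler.start("optimizer_step")
    beta_opt._on_after_step = lambda: self.trainer.profiler.stop("optimizer_step")

    inputs, targets = batch
    predictions = self(inputs)
    loss = torch.nn.functional.mse_loss(predictions, targets)

    gamma_opt.zero_grad()
    self.manual_backward(loss, retain_graph=True)
    gamma_opt.step()

    beta_opt.zero_grad()
    self.manual_backward(loss, inputs=[self.beta])
    beta_opt.step()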

[Suggestion for PyTorch Lightning]

...
def configure_optimizers():
    opt1 = Adam(...)
    opt2 = Adam(...)
    return (
        {"optimizer": opt1},
        {"optimizer": opt2, "do_not_count_global_step": True},
    )

yzslab commented 9 months ago

I've encountered the same problem and solved it as below. However, I'm not sure this method doesn't introduce another problem. If someone finds a possible edge case in my logic, please comment below.

[Background]

  • The trainer's global step is an alias of trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed.
  • When you call trainer.fit with manual optimization, the actual training logic (your LightningModule.training_step implementation) is executed in trainer.fit_loop.epoch_loop.manual_optimization.run(). (@jkyl mentioned the same thing above.)
  • In trainer.fit_loop.epoch_loop.manual_optimization.run(), three methods are called:

    • trainer.fit_loop.epoch_loop.manual_optimization.on_run_start
    • trainer.fit_loop.epoch_loop.manual_optimization.advance
    • trainer.fit_loop.epoch_loop.manual_optimization.on_run_end
  • In trainer.fit_loop.epoch_loop.manual_optimization.on_run_start, every optimizer's _on_before_step and _on_after_step are overridden, so each optimizer's step increases trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed by one.
  • The code below is the manual optimization class. The self.optim_step_progress.increment_completed() method is what increases trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed.
class _ManualOptimization(_Loop):
    """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
    entirely in the :meth:`~lightning.pytorch.core.module.LightningModule.training_step` and therefore the user is
    responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
    :meth:`~lightning.pytorch.core.module.LightningModule.training_step`) and passing through the output(s).

    """

    output_result_cls = ManualResult

    def __init__(self, trainer: "pl.Trainer") -> None:
        super().__init__(trainer)
        # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than
        # `_OptimizationProgress`
        self.optim_step_progress = _Progress.from_defaults(_ReadyCompletedTracker)

        self._output: _OUTPUTS_TYPE = {}

    def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
        self.on_run_start()
        with suppress(StopIteration):  # no loop to break at this level
            self.advance(kwargs)
        self._restarting = False
        return self.on_run_end()

    def on_run_start(self) -> None:
        # inject logic around the optimizer step
        for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
            lightning_optimizer._on_before_step = self._on_before_step
            lightning_optimizer._on_after_step = self._on_after_step

    def advance(self, kwargs: OrderedDict) -> None:
        """Performs the training step for manual optimization.

        Args:
            kwargs: The kwargs passed down to the hooks.

        """
        trainer = self.trainer

        # manually capture logged metrics
        training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
        del kwargs  # release the batch from memory
        self.trainer.strategy.post_training_step()
        result = self.output_result_cls.from_training_step_output(training_step_output)

        self._output = result.asdict()

    def on_run_end(self) -> _OUTPUTS_TYPE:
        """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
        output, self._output = self._output, {}  # free memory
        # reset logic around the optimizer step
        for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
            lightning_optimizer._on_before_step = do_nothing_closure
            lightning_optimizer._on_after_step = do_nothing_closure
        return output

    def _on_before_step(self) -> None:
        self.optim_step_progress.increment_ready()
        self.trainer.profiler.start("optimizer_step")

    def _on_after_step(self) -> None:
        self.trainer.profiler.stop("optimizer_step")
        self.optim_step_progress.increment_completed()

[Solution]

  • Since the manual optimization logic overrides the optimizer's hooks before "training_step" is called, we can re-override the optimizer's hooks at the top of "training_step".
  • Example:
...
    def training_step(self, batch, batch_idx):
        gamma_opt, beta_opt = self.optimizers()
        beta_opt._on_before_step = lambda : self.trainer.profiler.start("optimizer_step")
        beta_opt._on_before_step = lambda : self.trainer.profiler.stop("optimizer_step")
        ...

[Suggestion for PyTorch Lightning]

  • If this method seems safe, we could contribute a PR in two ways:

    1. Add this method as a guide in the PyTorch Lightning documentation (somewhere like the PyTorch Lightning Basic GAN Tutorial).
    2. Or, we could make LightningModule's configure_optimizers interface support options like this:
...
def configure_optimizers():
    opt1 = Adam(...)
    opt2 = Adam(...)
    return (
        {"optimizer": opt1},
        {"optimizer": opt2, "do_not_count_global_step": True},
    )

Thanks for your solution, it works! But there is a typo in your code: the line beta_opt._on_before_step = lambda : self.trainer.profiler.stop("optimizer_step") should assign _on_after_step instead.
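
For reference, the corrected pair of overrides (matching the fixed version earlier in the thread):

beta_opt._on_before_step = lambda: self.trainer.profiler.start("optimizer_step")
beta_opt._on_after_step = lambda: self.trainer.profiler.stop("optimizer_step")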

Fitree commented 9 months ago

@yzslab Great! Also, thanks for your comment. Typo fixed :)

askerlee commented 6 months ago

Thanks @Fitree for the neat fix! Do we need to update beta_opt._on_before_step and beta_opt._on_after_step at each step, or only the first step? Thanks.

Anner-deJong commented 4 months ago

Just ran into this as well, thanks @yzslab for the quick fix. Considering this is still a problem and this GitHub issue looks like it's going stale, I'll have a stab at getting a PR in.

@askerlee yes, the _on_before_step and _on_after_step functions get reassigned for each step, so you'll have to overwrite them in each step.

Separately, in the meantime, if anybody needs a quick fix for any number of optimizers, use this:

for i, opt in enumerate(self.optimizers()):
    opt.zero_grad()
    if i+1 < len(self.optimizers()):
        opt._on_before_step = lambda : self.trainer.profiler.start("optimizer_step")
        opt._on_after_step = lambda : self.trainer.profiler.stop("optimizer_step")

Anner-deJong commented 4 months ago

(@ repo owners feel free to assign the issue to me)

jkyl commented 3 months ago

Thank you for your PR @Anner-deJong! Hope it is merged soon