Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Example of gradient clipping with manual optimization does not handle gradient unscaling properly #18089

Open function2-llx opened 1 year ago

function2-llx commented 1 year ago

đź“š Documentation

The docs for manual optimization give an example of gradient clipping (added by #16023):

from lightning.pytorch import LightningModule

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # compute loss
        loss = self.compute_loss(batch)

        opt.zero_grad()
        self.manual_backward(loss)

        # clip gradients
        self.clip_gradients(opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm")

        opt.step()

However, it seems that this example does not handle gradient unscaling properly: when using mixed-precision training, the gradients should be unscaled before calling self.clip_gradients.
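For reference, plain PyTorch AMP prescribes exactly this order (see the torch.cuda.amp.GradScaler docs): unscale first, then clip, then step. A minimal sketch outside of Lightning, CUDA assumed:

import torch
from torch import nn

model = nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x, y = torch.randn(4, 8, device="cuda"), torch.randn(4, 1, device="cuda")
with torch.autocast("cuda", dtype=torch.float16):
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                        # gradients become unscaled here
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)  # clip the *unscaled* gradients
scaler.step(optimizer)                                            # skips the step on inf/nan gradients
scaler.update()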

cc @carmocca @justusschock @awaelchli @borda

function2-llx commented 1 year ago

I'm not sure whether this is a limitation or not; currently I can find no simple way to achieve this.

For automatic optimization, gradient unscaling is performed right after the optimizer closure (training step, zero grad, backward) and before the gradient clipping (called in self._after_closure):
https://github.com/Lightning-AI/lightning/blob/e9c42ed11f68aafc18fe64a26d87118d57a5743c/src/lightning/pytorch/plugins/precision/amp.py#L78-L84
https://github.com/Lightning-AI/lightning/blob/e9c42ed11f68aafc18fe64a26d87118d57a5743c/src/lightning/pytorch/plugins/precision/precision_plugin.py#L77-L87
https://github.com/Lightning-AI/lightning/blob/e9c42ed11f68aafc18fe64a26d87118d57a5743c/src/lightning/pytorch/plugins/precision/precision_plugin.py#L116-L125

However, for manual optimization, the calling order is:

  1. epoch_loop.manual_optimization.run()
  2. model.training_step()
  3. Inside the training step, the user calls manual_backward on the loss (the gradients are scaled) and then calls optimizer.step().
  4. In optimizer.step(), the gradients are unscaled, but gradient clipping of the unscaled gradients is disabled because of manual optimization (in _after_closure -> _clip_gradients).

Given all of the above, there seems to be no place for the user to perform gradient unscaling inside training_step, since the gradients are always unscaled in optimizer.step(). On the other hand, the user is also unable to clip the gradients after they are unscaled but before the optimizer's actual step.

So here comes a question: why not also allow automatic gradient clipping for manual optimization? If users are supposed to take care of gradient clipping themselves, most of the time they would simply call self.clip_gradients on the unscaled gradients, just like automatic optimization does; if they want to do something extra, they can do it via configure_gradient_clipping (see the sketch below).
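To illustrate the second half of that suggestion: under automatic optimization, Lightning calls this hook after unscaling, so custom behavior plus standard clipping can look roughly like the sketch below (the class name is made up; the hook signature is the 2.x one):

from lightning.pytorch import LightningModule

class ClipExample(LightningModule):
    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        # any extra work goes here (e.g. inspecting the already-unscaled gradients),
        # then delegate to the standard clipping
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )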

kkoutini commented 1 year ago

Hi, did you manage to find the correct way to do the unscaling before the clipping for manual optimization?

function2-llx commented 1 year ago

@kkoutini No, I gave up and used Fabric instead.
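For reference, a rough sketch of that Fabric route (Lightning >= 2.0, CUDA assumed for "16-mixed"; the toy model is illustrative, and fabric.clip_gradients is expected to handle the unscaling before clipping under AMP):

import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1, precision="16-mixed")
fabric.launch()

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x, y = fabric.to_device((torch.randn(4, 8), torch.randn(4, 1)))

loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
fabric.backward(loss)
fabric.clip_gradients(model, optimizer, max_norm=0.5)  # unscales (if needed), then clips
optimizer.step()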

manipopopo commented 2 months ago

We might consider calling clip_gradients within the on_before_optimizer_step hook.

from lightning.pytorch import LightningModule

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # compute loss
        loss = self.compute_loss(batch)

        opt.zero_grad()
        self.manual_backward(loss)

        opt.step()

    # # Include optimizer_idx parameter when using Lightning 1.x.x
    # def on_before_optimizer_step(self, optimizer, optimizer_idx):
    def on_before_optimizer_step(self, optimizer):
        # clip gradients
        self.clip_gradients(optimizer, gradient_clip_val=0.5, gradient_clip_algorithm="norm")

Oktai15 commented 2 months ago

@manipopopo are you sure that on_before_optimizer_step will be run when automatic_optimization=False?

UPD: I made a mistake, I meant automatic_optimization=False instead of automatic_optimization=True; the message has been updated.

manipopopo commented 2 months ago

> are you sure that on_before_optimizer_step will be run when automatic_optimization=True?

When automatic_optimization=True, the currently recommended approach for gradient clipping, according to the documentation, is to use Trainer(gradient_clip_val=...). This should ensure that gradient clipping is handled within the automated optimization loop.
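For completeness, a minimal sketch of that documented approach; these two Trainer arguments are the standard ones:

from lightning.pytorch import Trainer

# With automatic optimization, the Trainer unscales (under AMP) and then clips
# inside its own optimization loop.
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")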

When automatic_optimization=False, calling opt.step() triggers the optimizer_step method of the MixedPrecision plugin. According to the implementation of MixedPrecision.optimizer_step:

        if not _optimizer_handles_unscaling(optimizer) and not skip_unscaling:
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)  # type: ignore[arg-type]

        self._after_closure(model, optimizer)

Our custom on_before_optimizer_step hook is invoked within `self._after_closure`, which occurs after `self.scaler.unscale_`. This ensures that gradient unscaling is properly handled before clipping, aligning with the expected behavior.

The current documentation states that on_before_optimizer_step is "Called before optimizer.step()". However, it's worth noting that the exact timing and behavior of these hooks may be subject to change in future versions of PyTorch Lightning.
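One way to sanity-check this ordering (a sketch mirroring the example in the on_before_optimizer_step docs; the rest of the module is as in the snippet above) is to log the gradient norms from inside the hook. Under "16-mixed" they should already be in the unscaled range by the time the hook runs:

from lightning.pytorch import LightningModule
from lightning.pytorch.utilities import grad_norm

class SimpleModel(LightningModule):
    # __init__ and training_step as in the snippet above

    def on_before_optimizer_step(self, optimizer):
        # gradients are already unscaled here under AMP, so these norms
        # should look "normal" rather than multiplied by the loss scale
        self.log_dict(grad_norm(self, norm_type=2))
        self.clip_gradients(optimizer, gradient_clip_val=0.5, gradient_clip_algorithm="norm")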

Oktai15 commented 2 months ago

@manipopopo thank you for the brilliant explanation. Could you clarify how to use the workaround with clipping in on_before_optimizer_step if I have two optimizers (I use PTL >= 2.x.x)?

manipopopo commented 2 months ago

The on_before_optimizer_step method will receive different optimizers as its input parameter. If you want to apply different behaviors—such as using different gradient clip values for each optimizer—you can adjust the method accordingly:


    def training_step(self, batch, batch_idx):
        # ...

        lightning_optimizer_0, lightning_optimizer_1 = self.optimizers()

        # ...

        lightning_optimizer_0.step() # on_before_optimizer_step will receive lightning_optimizer_0.optimizer

        # ...

        lightning_optimizer_1.step() # on_before_optimizer_step will receive lightning_optimizer_1.optimizer

    def on_before_optimizer_step(self, optimizer):
        lightning_optimizer_0, lightning_optimizer_1 = self.optimizers()
        optimizer_0 = lightning_optimizer_0.optimizer
        optimizer_1 = lightning_optimizer_1.optimizer

        assert optimizer in {optimizer_0, optimizer_1}

        # Suppose we want to use different gradient_clip_val for different optimizers.
        if optimizer is optimizer_0:
            gradient_clip_val = 10.0
        else:
            gradient_clip_val = 20.0

        self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm="norm")

This template should work with the current version of PyTorch Lightning. However, the behavior of hooks and internal methods in PyTorch Lightning may change in future updates. To ensure your implementation remains robust across versions, it's advisable to write small unit tests to verify functionality.
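Along those lines, here is a sketch of such a test (names like TinyModel and test_manual_clip_with_amp are made up; Lightning >= 2.0 and a CUDA device are assumed, since "16-mixed" uses the AMP grad scaler):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

    def on_before_optimizer_step(self, optimizer):
        self.clip_gradients(optimizer, gradient_clip_val=0.5, gradient_clip_algorithm="norm")
        # after clipping, the (unscaled) total norm must not exceed the clip value
        total_norm = torch.linalg.vector_norm(
            torch.stack([p.grad.detach().norm(2) for p in self.parameters() if p.grad is not None])
        )
        assert total_norm <= 0.5 + 1e-4, f"gradients were not clipped: {total_norm}"

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def test_manual_clip_with_amp():
    dataset = TensorDataset(torch.randn(16, 8), torch.randn(16, 1))
    trainer = Trainer(
        fast_dev_run=1,
        accelerator="gpu",
        devices=1,
        precision="16-mixed",
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(TinyModel(), DataLoader(dataset, batch_size=4))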