Lightning-AI / pytorch-lightning

Wrap optimizer in LightningOptimizer object for manual optimization #4663

Closed · edenlightning closed this issue 3 years ago

edenlightning commented 3 years ago

Currently, users override optimizer_step in their LightningModule and therefore lose support for TPUs, AMP, etc. By wrapping their optimizers, we can make sure that when they call optimizer.step, everything is done properly.

It also allows us to remove the need for manual_backward / manual_optimizer_step, making the API cleaner and reducing possible bugs.
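
To make the wrapping idea concrete, here is a rough sketch of what such a wrapper could look like. This is purely illustrative: the class internals and the trainer hook used below are hypothetical, not the final implementation.

    # hypothetical sketch only; names and internals are illustrative
    class LightningOptimizer:
        def __init__(self, optimizer, trainer):
            self._optimizer = optimizer
            self._trainer = trainer

        def backward(self, loss):
            # route backward through the trainer so AMP scaling, TPU support,
            # etc. are handled in one place
            self._trainer.train_loop.backward(loss, self._optimizer)  # hypothetical hook

        def step(self, closure=None, make_optimizer_step=True):
            if closure is not None:
                closure()  # computes the loss and populates gradients
            if make_optimizer_step:
                self._optimizer.step()
                self._optimizer.zero_grad()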

Here is the new API for manual optimization; it also works for automatic optimization.

    Example::

        def training_step(...):
            (opt_a, opt_b) = self.optimizers()
            loss_a = ...

            # automatically applies scaling, model toggling, gradient accumulation, TPU handling, zero_grad, etc.
            opt_a.backward(loss_a)
            opt_a.step()

    Example::

        def training_step(...):
            (opt_a, opt_b) = self.optimizers()
            loss_a = ...

            # automatically applies scaling, etc...
            def closure_a():
                loss_a = ...
                opt_a.backward(loss_a)

            opt_a.step(closure=closure_a)
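
For reference, the closure semantics here mirror the standard torch.optim API, where the optimizer itself invokes the closure (optimizers like LBFGS may call it several times per step). A plain PyTorch equivalent, for comparison only:

    import torch

    model = torch.nn.Linear(4, 1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    opt = torch.optim.LBFGS(model.parameters())

    def closure():
        # LBFGS may call this multiple times per step to re-evaluate the loss
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    opt.step(closure)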

    Example (GAN training)::

        def training_step(self, batch, batch_idx):

            # emulate gans training
            opt_gen, opt_dis = self.optimizers()

            # Note: be careful not to log to the same key with self.log in both
            # closures, as the values would be aggregated together at epoch end

            def gen_closure():
                # ... forward and compute the generator loss
                loss_gen = ...
                self.log("loss_gen", loss_gen, on_step=True, on_epoch=True)
                opt_gen.backward(loss_gen)

            def dis_closure():
                # ... forward and compute the discriminator loss
                loss_dis = ...
                self.log("loss_dis", loss_dis, on_step=True, on_epoch=True)
                opt_dis.backward(loss_dis)

            # this will accumulate gradients for 2 batches and then call opt_gen.step()
            opt_gen.step(closure=gen_closure, make_optimizer_step=batch_idx % 2 == 0)

            # update the discriminator every 4 batches;
            # no gradient accumulation is used for the discriminator here
            if batch_idx % 4 == 0:
                # Note: set make_optimizer_step=True, otherwise the default from
                # Trainer(accumulate_grad_batches=x) is used
                opt_dis.step(closure=dis_closure, make_optimizer_step=True)

Here is how you would do it with the current manual optimization API:

        def training_step(...):
            (opt_a, opt_b) = self.optimizers()
            loss = ...
            # automatically applies scaling, etc...

            self.manual_backward(loss, opt_a)

            # this accumulates gradients for `accumulate_grad_batches` batches
            # and then runs opt_a.step()
            self.manual_optimizer_step(opt_a)

And here is how people currently do it by overriding optimizer_step:

            def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                               optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
                # update generator opt every 2 steps
                if optimizer_idx == 0:
                    if batch_idx % 2 == 0:
                        optimizer.step(closure=optimizer_closure)
                        optimizer.zero_grad()

                # update discriminator opt every 4 steps
                if optimizer_idx == 1:
                    if batch_idx % 4 == 0:
                        optimizer.step(closure=optimizer_closure)
                        optimizer.zero_grad()

There is a flag, enable_pl_optimizer, to switch it off.
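
For example (assuming the flag is exposed on the Trainer as proposed):

    from pytorch_lightning import Trainer

    # opt out of the wrapping and get the raw torch.optim optimizers back
    trainer = Trainer(enable_pl_optimizer=False)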

It will also allow us to move the AMP logic from the accelerators into LightningOptimizer, cleaning up the internal code.
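
For example, the step method from the sketch above could grow an AMP branch. Again, the trainer attributes here are hypothetical; this just assumes native AMP with torch.cuda.amp.GradScaler:

    def step(self, closure=None, make_optimizer_step=True):
        if closure is not None:
            closure()
        if not make_optimizer_step:
            return  # gradients were only accumulated
        if self._trainer.use_native_amp:  # hypothetical attribute
            # GradScaler.step unscales gradients and skips the update on inf/nan
            self._trainer.scaler.step(self._optimizer)
            self._trainer.scaler.update()
        else:
            self._optimizer.step()
        self._optimizer.zero_grad()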

Finally, we are working on some DDPPlugin changes and we need the step function to be wrapped there as well.

It is still an idea and we haven't decided whether we will integrate it. I am really keen to hear your feedback @carmocca @awaelchli @justusschock @williamFalcon :)

edenlightning commented 3 years ago

@awaelchli