Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Using Manual Optimisation #5108

Closed KamWithK closed 3 years ago

KamWithK commented 3 years ago

❓ Questions and Help

What is your question?

The documentation for manual optimisation is vague and doesn't provide any complete examples of how to use it correctly. Hence, I need to know what Lightning takes care of and what I have to do within the training loop. To be more specific, do I need to zero gradients and call step on the learning rate scheduler myself?

Given that this is "manual" mode I'd be fine having to do it (and half expect I will), but what's extremely confusing is that the given examples seem to switch between stating/showing gradients being manually zeroed and not being touched at all... Take the current optimizer section example, it does not show anything being zeroed. On the other hand, the documentation for trainer's manual optimisation shows the gradients being explicitly zeroed. So which is it?

Furthermore, how do I access/step for the learning rate scheduler (or is that not something for me to handle here)?

What have you tried?

For a little context, what I'm trying to do is to port regular PyTorch GAN code into Lightning. The module dynamically selects whether to train the generator or discriminator at the start of each batch depending on the discriminator's loss. So if the loss is below some threshold it'll backpropagate & optimise for the discriminator, otherwise generator.

In versions <=1.0.8 I used automatic optimisation with a custom optimizer step function. That was, in all honesty, quite clunky, and it no longer works with accumulate_grad_batches (which we need because we're working with extremely large 3D data). So today I wrote code that checks whether the current batch is for training the discriminator or the generator and, based on that, runs self.manual_backward(loss, optimizer); optimizer.step(). I'm pleased that it runs, but I can't find any documentation that specifies whether this is enough for the scheduler and accumulated gradient batches to work.
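
The per-batch selection rule described above can be sketched without any framework. This is only an illustration of the decision logic: the names choose_network and plan_updates and the threshold value 0.5 are hypothetical and do not come from the original code.

```python
from typing import List


def choose_network(dis_loss: float, threshold: float) -> str:
    """Pick which network to update for the current batch.

    Follows the rule in the post: train the discriminator while its
    loss is below the threshold, otherwise train the generator.
    """
    return "discriminator" if dis_loss < threshold else "generator"


def plan_updates(dis_losses: List[float], threshold: float = 0.5) -> List[str]:
    """Apply the rule to a sequence of observed discriminator losses."""
    return [choose_network(loss, threshold) for loss in dis_losses]
```

In a LightningModule the returned name would decide which optimizer receives manual_backward and step for that batch.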

What's your environment?

Thanks so much for any help!

CarloLucibello commented 3 years ago

Also, if the Trainer's option has been deprecated in favour of a LightningModule's property #5011 https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.automatic_optimization the documentation should be made consistent.

It's also not clear whether there is anything else we should worry about when doing manual optimization in a distributed setting.

AliKarimi74 commented 3 years ago

I have exactly the same issue.

automatic_optimization is deprecated in v1.1.0, but the documentation does not give an alternative. There is an automatic_optimization property in LightningModule, but it's not obvious how we can change it. There are also inconsistencies in the documentation about manually calling zero_grad, and nowhere does it say whether the scheduler's step is handled by Lightning or not.

blakedewey commented 3 years ago

+1 on this issue. Documentation is very unclear about how to work with this going forward.

KamWithK commented 3 years ago

> automatic_optimization is deprecated in v1.1.0, but the documentation does not give an alternative. There is an automatic_optimization property in LightningModule, but it's not obvious how we can change it. There are also inconsistencies in the documentation about manually calling zero_grad, and nowhere does it say whether the scheduler's step is handled by Lightning or not.

I was confused by this but managed to figure this nuance out. You have to define it as a property, so something like:

class SomeModule(LightningModule):
    @property
    def automatic_optimization(self) -> bool:
        return False

There's probably an easier way (I'd hope you can set it with a one-liner), but I haven't been able to figure it out quite yet.
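
A framework-free sketch of why the override works and why a plain assignment may not. Base below is a hypothetical stand-in for LightningModule, not the real class, and it assumes the base property is read-only (has no setter), which this thread does not confirm for 1.1.0. The helper can_assign is also made up for the illustration.

```python
class Base:
    """Hypothetical stand-in for LightningModule (assumption: no setter)."""

    @property
    def automatic_optimization(self) -> bool:
        return True


class SomeModule(Base):
    # Redefining the property in the subclass shadows the base one.
    @property
    def automatic_optimization(self) -> bool:
        return False


def can_assign(obj) -> bool:
    """Try the one-liner `obj.automatic_optimization = False`."""
    try:
        obj.automatic_optimization = False
    except AttributeError:
        # A property without a setter rejects assignment.
        return False
    return True
```

If the base class did define a setter, assigning the attribute in __init__ would work as a one-liner; later Lightning releases document that form, but it is untested here against 1.1.0.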

KamWithK commented 3 years ago

That though still doesn't explain what is and isn't managed by Lightning itself once you've turned automatic optimisation off :(

blakedewey commented 3 years ago

Also, I have been getting the error "ValueError: Your LightningModule defines 2 optimizers but training_step is missing the "optimizer_idx" argument." when using manual optimization via the LightningModule property.

KamWithK commented 3 years ago

> Also, I have been getting the error "ValueError: Your LightningModule defines 2 optimizers but training_step is missing the "optimizer_idx" argument." when using manual optimization via the LightningModule property.

Yeah we still need to pass it in (although it's largely useless for manual optimisation - from what I'm reading)

heng-yuwen commented 3 years ago

I am also confused about how to handle the scheduler manually. From what I've seen, the scheduler does not work properly when using a checkpoint.

tchaton commented 3 years ago

Hey everyone,

It is possible to use it as follows:

    def training_step(self, batch, batch_idx, optimizer_idx):
        # fetch the optimizers returned by configure_optimizers
        (opt_a, opt_b) = self.optimizers()
        loss_a = ...
        # automatically applies scaling, etc...
        self.manual_backward(loss_a, opt_a)
        opt_a.step()

Best, T.C
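
For readers unsure what "manual" bookkeeping covers, here is a minimal framework-free sketch of the pieces asked about in this thread: accumulating gradients, stepping, zeroing, and decaying the learning rate. It is not Lightning's API and makes no claim about which of these Lightning performs for you. The function train, the toy objective (w - 3)^2, and all parameter names are assumptions made up for the illustration.

```python
def train(num_batches=8, accumulate=2, lr=0.1, decay=0.5, decay_every=2):
    """Minimise (w - 3)^2 with hand-written optimisation bookkeeping."""
    w = 0.0      # the single trainable parameter
    grad = 0.0   # accumulated gradient (what zero_grad would reset)
    steps = 0    # number of optimizer steps taken so far
    for batch_idx in range(num_batches):
        # "backward": add this batch's gradient, scaled so that the
        # accumulated value is an average over `accumulate` batches
        grad += 2.0 * (w - 3.0) / accumulate
        if (batch_idx + 1) % accumulate == 0:
            w -= lr * grad   # counterpart of optimizer.step()
            grad = 0.0       # counterpart of optimizer.zero_grad()
            steps += 1
            if steps % decay_every == 0:
                lr *= decay  # counterpart of scheduler.step()
    return w, steps, lr
```

Forgetting the reset line makes every later step reuse stale gradients, which is why the zero_grad inconsistency in the documentation matters.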

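tchaton commented 3 years ago

    import torch
    import torch.nn.functional as F
    from pytorch_lightning import Trainer

    # BoringModel comes from Lightning's test suite; tmpdir is a pytest fixture.

    class TestModel(BoringModel):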
        def training_step(self, batch, batch_idx, optimizer_idx):

            # emulate gans training
            opt_gen, opt_dis = self.optimizers()

            # Note: Be careful not to log to the same key with self.log in both closures,
            # as the values will be aggregated together on epoch_end

            def compute_loss():
                x = batch[0]
                x = F.dropout(x, 0.1)
                predictions = self(x)
                predictions = F.dropout(predictions, 0.1)
                loss = self.loss(None, predictions)
                return loss

            def gen_closure():
                loss_gen = compute_loss()
                self.log("loss_gen", loss_gen, on_step=True, on_epoch=True)
                self.manual_backward(loss_gen, opt_gen)

            def dis_closure():
                loss_dis = compute_loss()
                self.log("loss_dis", loss_dis, on_step=True, on_epoch=True)
                self.manual_backward(loss_dis, opt_dis)

            # this will accumulate gradients for 2 batches and then call opt_gen.step()
            opt_gen.step(closure=gen_closure, make_optimizer_step=(batch_idx % 2 == 0), optim='sgd')

            # update discriminator every 4 batches
            # therefore, no gradient accumulation for discriminator
            if batch_idx % 4 == 0 :
                # Note: Set make_optimizer_step to True, or it will default to
                # Trainer(accumulate_grad_batches=x)
                opt_dis.step(closure=dis_closure, make_optimizer_step=True)

        def training_epoch_end(self, outputs) -> None:
            # outputs should be an array with an entry per optimizer
            assert len(outputs) == 2

        def configure_optimizers(self):
            optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
            return [optimizer_gen, optimizer_dis]

        @property
        def automatic_optimization(self) -> bool:
            return False

    model = TestModel()
    model.val_dataloader = None
    model.training_epoch_end = None

    limit_train_batches = 8
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=limit_train_batches,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        accumulate_grad_batches=2,
        enable_pl_optimizer=True,
    )

    trainer.fit(model)
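
To make the update frequencies in the example above concrete, the sketch below only enumerates the two conditions used in training_step for limit_train_batches=8. It does not run Lightning, and it assumes an optimizer step happens exactly on the batches where make_optimizer_step is True. The function name step_schedule is hypothetical.

```python
def step_schedule(num_batches=8):
    """List the batch indices on which each optimizer would step."""
    # generator: make_optimizer_step=(batch_idx % 2 == 0)
    gen = [i for i in range(num_batches) if i % 2 == 0]
    # discriminator: stepped only inside `if batch_idx % 4 == 0`
    dis = [i for i in range(num_batches) if i % 4 == 0]
    return gen, dis
```

Under that assumption the generator steps on batches 0, 2, 4 and 6, and the discriminator on batches 0 and 4.
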

showgood163 commented 3 years ago

Hi @tchaton! I have some questions regarding this manual mode.

  1. opt_dis.step(closure=dis_closure, make_optimizer_step=True): is this step function the one in PyTorch? Does the closure here support a function with input parameters?

  2. What should I do if I have different losses for the same optimizer that need to be optimized in one batch but with separate steps? For example, a normal discriminator loss that needs to be updated every step, and a discriminator regularization loss that needs to be updated every 16 steps, after the step for the normal loss.

  3. Will there be a detailed elaboration of this manual mode?
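
One possible way to think about questions 1 and 2, sketched without Lightning or PyTorch. PyTorch optimizers call the closure with no arguments, so parameters can be bound in advance with functools.partial. Everything named here (fake_step, loss_fn, run) is a hypothetical stand-in, and reading "every 16 steps" as batch_idx % 16 == 0 is an assumption.

```python
from functools import partial


def fake_step(closure):
    """Stand-in for optimizer.step(closure=...): calls closure()."""
    return closure()


def loss_fn(name, batch_idx):
    """Placeholder for a loss computation that needs input parameters."""
    return f"{name}@{batch_idx}"


def run(num_batches, reg_every=16):
    """Step the main loss every batch, then the regulariser every 16th."""
    calls = []
    for batch_idx in range(num_batches):
        # partial binds the arguments, leaving a zero-argument callable
        calls.append(fake_step(partial(loss_fn, "dis", batch_idx)))
        if batch_idx % reg_every == 0:
            # separate step, taken after the main discriminator step
            calls.append(fake_step(partial(loss_fn, "reg", batch_idx)))
    return calls
```

With a real optimizer, each fake_step call would correspond to its own opt_dis.step(closure=...) call, so the two losses are applied as separate updates in the order shown.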

pierresegonne commented 3 years ago

Thanks for the thread, it greatly helped me set up a manual optimisation procedure correctly.

I'd like to point out that the property override

class SomeModule(LightningModule):
    @property
    def automatic_optimization(self) -> bool:
        return False

was not sufficient for me. I also had to set automatic_optimization=False when instantiating my trainer with Trainer.from_argparse_args(..)