Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

What is the correct way to set accumulate_grad_batches more than 1 and also a customized optimizer_step? #5054

Closed vincentzlt closed 3 years ago

vincentzlt commented 3 years ago

I had accumulate_grad_batches more than 1 and also a customized optimizer_step, which works well in 1.0.8.

In the latest release, I got an error raised in:

https://github.com/PyTorchLightning/pytorch-lightning/blob/02152c17299eaacb26748a858d1a3545b419b93a/pytorch_lightning/trainer/configuration_validator.py#L92

What is the best way to do this? I cannot find any documentation on it.

vincentzlt commented 3 years ago

Can anyone help me out? How do I properly customize optimizer_step and use accumulate_grad_batches at the same time?

KamWithK commented 3 years ago

I'm in the same situation. I recently spent a lot of time converting a regular PyTorch GAN to Lightning. The difficult part was keeping automatic optimisation. Quite a few functions, such as optimizer_step, had to be overridden so we could dynamically switch between training the generator and the discriminator based on their current losses.

Accumulating gradient batches is a necessary feature for us, so it's annoying that 1.1.0 just throws this error. The docs are also unclear on whether gradient accumulation is currently available for manual optimisation (stable and future both just say it's coming in 1.1.0, which is now out...).

It would be absolutely amazing if someone could clarify how accumulating gradients works with manual optimisation, and whether there's a way to override the optimizer step without losing this (important) functionality. Maybe just turn this into a warning instead of an error?

umbertov commented 3 years ago

@vincentzlt @KamWithK I think something like this would work?

if batch_idx % grad_acc_steps == 0:
    # do the thing

rohitgr7 commented 3 years ago

cc @tchaton

KamWithK commented 3 years ago

> @vincentzlt @KamWithK I think something like this would work?
>
>     if batch_idx % grad_acc_steps == 0:
>         # do the thing

We don't really want to do it manually; it's something the Lightning library handles for us. Considering that it handles it in manual mode, it doesn't really make sense that we can't use it in automatic mode just because we've overridden the optimizer's step function, right?
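For anyone who does try the manual check quoted above, here is a minimal sketch in plain Python (no Lightning involved) of which batches it would step on. The helper `stepping_batches` and its `shift` parameter are hypothetical names used only for illustration. It suggests that `batch_idx % grad_acc_steps == 0` already fires on batch 0, after a single backward pass, while a `(batch_idx + 1)` variant waits until a full group of batches has been seen.

```python
def stepping_batches(num_batches, grad_acc_steps, shift=0):
    """Return the batch indices at which the optimizer would step.

    `shift=0` models `batch_idx % grad_acc_steps == 0`;
    `shift=1` models `(batch_idx + 1) % grad_acc_steps == 0`.
    """
    return [i for i in range(num_batches) if (i + shift) % grad_acc_steps == 0]


# Steps on the very first batch, before anything has accumulated.
print(stepping_batches(8, 4))           # [0, 4]

# Steps only after every fourth batch.
print(stepping_batches(8, 4, shift=1))  # [3, 7]
```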

Riccorl commented 3 years ago

Is there a solution for this (except implementing manual optimization)?

Orlllem commented 3 years ago

I was able to get this working by changing only optimizer_step. It seems to work, but it would be nice to check whether this implementation could break in the general case.

# learning rate warm-up
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                   on_tpu=False, using_native_amp=False, using_lbfgs=False):
    # scale the learning rate linearly while still in the warm-up phase
    if self.trainer.global_step < self.hparams["warmup_steps"]:
        lr_scale = min(1., float(self.trainer.global_step + 1) / self.hparams["warmup_steps"])
        for pg in optimizer.param_groups:
            pg['lr'] = lr_scale * self.hparams["lr"]

    # update params only once every accum_grads batches
    if (batch_idx + 1) % self.hparams["accum_grads"] == 0:
        optimizer.step(closure=optimizer_closure)

chadHGY commented 3 years ago

I just ran into a similar problem. I found that a related change has been merged into the main branch: https://github.com/PyTorchLightning/pytorch-lightning/pull/7980/files

Specifically, if you can't upgrade to the latest version, you can manually change configuration_validator.py: https://github.com/DavidMChan/pytorch-lightning/blob/2a76b31a49e1f18663f1c239803a898765d6abde/pytorch_lightning/trainer/configuration_validator.py#L86
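Whichever route you take, the warm-up arithmetic from the optimizer_step workaround earlier in the thread can be checked on its own, outside a Trainer. The sketch below is plain Python, and `warmup_lr` is a hypothetical helper name, not part of Lightning. It assumes the same linear ramp as that snippet: the rate is scaled by (step + 1) / warmup_steps until the warm-up ends, then left at the base value.

```python
def warmup_lr(global_step, base_lr, warmup_steps):
    """Linearly ramp the learning rate over the first `warmup_steps` steps."""
    if global_step < warmup_steps:
        return base_lr * min(1.0, (global_step + 1) / warmup_steps)
    return base_lr


# Print the schedule for a 4-step warm-up towards a base rate of 0.1.
for step in range(6):
    print(step, round(warmup_lr(step, 0.1, 4), 4))
```

Printing the schedule for a small `warmup_steps` like this is a quick way to confirm that the ramp reaches the base rate on the last warm-up step and stays there afterwards.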