Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Pass second-order closure to all optimizers (not just LBFGS) #2785

Closed · awaelchli closed this 4 years ago

awaelchli commented 4 years ago

🚀 Feature

I could be wrong, but I noticed the following in the code of the LightningModule's optimizer_step:

        if on_tpu:
            xm.optimizer_step(optimizer)
        elif using_native_amp:
            self.trainer.scaler.step(optimizer)
        elif using_lbfgs:
            optimizer.step(second_order_closure)
        else:
            optimizer.step()

If someone uses a custom optimizer whose step needs the closure to re-evaluate the loss multiple times, this won't work.
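To make this concrete, here is a minimal sketch of such an optimizer (all names are made up for illustration, this is neither Lightning nor PyTorch code): a toy backtracking line-search optimizer whose step calls the closure several times, just like LBFGS does.

    import torch
    from torch.optim import Optimizer

    class BacktrackingSGD(Optimizer):
        """Toy optimizer: shrinks the step size until the loss decreases."""

        def __init__(self, params, lr=0.1, shrink=0.5, max_tries=5):
            super().__init__(params, dict(lr=lr, shrink=shrink, max_tries=max_tries))

        @torch.no_grad()
        def step(self, closure):
            # The closure runs forward + backward and returns the loss, so this
            # optimizer cannot work with a bare optimizer.step().
            with torch.enable_grad():
                loss = closure()
            for group in self.param_groups:
                lr = group["lr"]
                params = [p for p in group["params"] if p.grad is not None]
                grads = [p.grad.clone() for p in params]
                start = [p.detach().clone() for p in params]
                for _ in range(group["max_tries"]):
                    for p, g, s in zip(params, grads, start):
                        p.copy_(s).add_(g, alpha=-lr)
                    with torch.enable_grad():
                        new_loss = closure()  # re-evaluate at the trial point
                    if new_loss <= loss:
                        break
                    lr *= group["shrink"]  # shrink the step and try again
            return loss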

Pitch

Since the step method of every class that inherits from torch.optim.Optimizer accepts an optional closure (even optimizers that don't use it), we could just do

        if on_tpu:
            xm.optimizer_step(optimizer)
        elif using_native_amp:
            self.trainer.scaler.step(optimizer)
        # elif using_lbfgs:
        #     optimizer.step(second_order_closure)
        else:
            optimizer.step(second_order_closure)

and drop the "using_lbfgs" argument?
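For reference, a minimal plain-PyTorch sketch (outside Lightning) of the point above: the closure argument of step is optional and its loss is returned, so passing it unconditionally is safe even for optimizers that ignore it.

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    def closure():
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
        return loss

    # SGD does not need the closure, but step() accepts it and returns its loss;
    # LBFGS (and custom second-order optimizers) call it as often as they need.
    loss = optimizer.step(closure)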

Alternatives

The user has to override optimizer_step themselves.
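A rough sketch of that workaround, assuming the optimizer_step hook signature shown in the snippet above (exact argument names may vary between versions):

    import pytorch_lightning as pl

    class MyModule(pl.LightningModule):
        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                           second_order_closure=None, on_tpu=False,
                           using_native_amp=False, using_lbfgs=False):
            # Always forward the closure so that any optimizer which re-evaluates
            # the loss (LBFGS, line-search methods, ...) keeps working.
            optimizer.step(second_order_closure)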

Borda commented 4 years ago

It sounds like a good way to go to me... cc: @williamFalcon

williamFalcon commented 4 years ago

We had that before but removed it for some reason; I can't remember why. But I originally agreed.

Let's do this after 0.9.0, since we need to dig back into why it changed.

edenlightning commented 4 years ago

@awaelchli want to take a look at this now? Or is it better post v1?

awaelchli commented 4 years ago

@williamFalcon LBFGS not being compatible with native AMP is the only exception I found in the code; maybe that's what you mean? @edenlightning I don't think it's essential for v1.0, but if you want the final optimizer_step API to be without this "using_lbfgs" argument, I could send a PR.