davda54 / sam

SAM: Sharpness-Aware Minimization (PyTorch)
MIT License

Sharpness-Aware Minimization as a normal (closure-based) optimizer #8

Closed · rmcavoy closed this issue 3 years ago

rmcavoy commented 3 years ago

In your implementation of SAM, the optimization is split into `first_step` and `second_step`, and calling the usual `step` method raises an error, so SAM cannot be used like a normal optimizer. This doesn't have to be the case, because `step` takes an optional `closure` argument, a feature used by optimizers such as LBFGS that need to re-evaluate the loss and gradients at intermediate states (https://pytorch.org/docs/stable/optim.html?highlight=bfgs#torch.optim.LBFGS). If we use a closure, the training routine for SAM looks like:

model = YourModel()
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)

for input, output in data:
    def closure():
        # re-evaluate the loss and gradients at the perturbed weights
        optimizer.zero_grad()
        loss = loss_function(output, model(input))
        loss.backward()
        return loss

    # first forward-backward pass at the current weights
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.step(closure)
    optimizer.zero_grad()

where `step(closure)` is defined as:

    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires a closure, but it was not provided"

        self.first_step()   # perturb the weights towards the local worst case w + e(w)
        closure()           # forward-backward pass at the perturbed weights
        self.second_step()  # restore the original weights and apply the base optimizer update
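
For comparison, this mirrors the closure pattern that `torch.optim.LBFGS` already uses. A minimal sketch, reusing the `model`, `data`, and `loss_function` names from the snippet above:

    optimizer = torch.optim.LBFGS(model.parameters(), lr=1.0)

    for input, output in data:
        def closure():
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            return loss
        optimizer.step(closure)  # LBFGS calls the closure whenever it needs to re-evaluate the loss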

This change would bring SAM more in line with the standard PyTorch optimizer interface and make it easier to adopt, since a fraction of users already write closure-based training loops. At the very least, it gives people the choice of which pattern to use.
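
For a side-by-side comparison, the current two-step usage from the README looks roughly like this (the exact flags, e.g. `zero_grad=True`, may differ):

    for input, output in data:
        # first forward-backward pass at the current weights
        loss = loss_function(output, model(input))
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass at the perturbed weights
        loss_function(output, model(input)).backward()
        optimizer.second_step(zero_grad=True)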

davda54 commented 3 years ago

Thank you very much for this suggestion! :) I was not aware there is such a standardized pattern. I've updated the code so that one can choose between the two options.
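
The combined `step` can then look roughly like this (a sketch assuming the existing `first_step`/`second_step` methods and a `zero_grad` flag on `first_step`; see the repository for the exact implementation):

    def step(self, closure=None):
        assert closure is not None, "SAM requires a closure that re-evaluates the loss, but none was provided"
        closure = torch.enable_grad()(closure)  # make sure the closure runs with gradients enabled

        self.first_step(zero_grad=True)  # ascend to the perturbed weights w + e(w)
        closure()                        # forward-backward pass at the perturbed weights
        self.second_step()               # return to w and apply the base optimizer update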