simonpokorny opened this issue 1 year ago
Can you provide more details? This example shows it working:
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        print(self.trainer.global_step)
        opt = self.optimizers()
        opt.zero_grad()
        loss = self(batch).sum()
        loss.backward()
        opt.step()
        return loss.detach()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        limit_train_batches=3,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
Thanks, for sure.
I used your example with the custom optimizer (see below), and the global step is not increasing.
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

from classifiers.sam import SAM


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.labels = torch.randint(low=0, high=2, size=(length,))  # one label per sample

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(32, 2)
        self.automatic_optimization = False
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        data, labels = batch
        opt = self.optimizers()

        # first forward-backward pass
        pred = self.model(data)
        loss_1 = self.loss_fn(pred, labels)
        self.manual_backward(loss_1)
        opt.first_step(zero_grad=True)

        # second forward-backward pass
        pred = self.model(data)
        loss_2 = self.loss_fn(pred, labels)
        self.manual_backward(loss_2)
        opt.second_step(zero_grad=True)

        print(self.trainer.global_step)
        return loss_2

    def configure_optimizers(self):
        base_optimizer = torch.optim.Adam
        optimizer = SAM(self.parameters(), base_optimizer, rho=1, adaptive=True, lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


def run():
    train_data = DataLoader(RandomDataset(size=32, length=64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=2,
        limit_train_batches=3,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
The SAM optimizer is from https://github.com/davda54/sam:
class SAM(torch.optim.Optimizer):
    """
    SAM Optimizer
    https://github.com/davda54/sam
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
Okay. This happens because we assume there will be an optimizer.step() call, which is what we wrap to inject the strategy-specific logic (e.g. DDP): https://github.com/Lightning-AI/lightning/blob/50331e08e111d6b9ebb25a21a86b7170b46c5f1f/src/pytorch_lightning/core/optimizer.py#L101-L173
The call chain is LightningModule.training_step() -> _LightningOptimizer.step() -> Strategy.optimizer_step() -> PrecisionPlugin.optimizer_step() -> Optimizer.step().
Your use of the SAM optimizer violates this assumption: you are calling two different step methods ({first,second}_step), which are not wrapped like .step(). It's not clear to me whether you would expect the global_step count to increase after each of them, or only after second_step().
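To illustrate the assumption, here is a hand-written sketch (not Lightning's actual implementation): only a call that goes through the wrapper's step() can advance the step counter, so any extra method on the raw optimizer is invisible to it.

```python
# Illustrative sketch only, not Lightning's real wrapper. It shows why custom
# methods such as first_step()/second_step() bypass the step counting.
class StepCountingWrapper:
    def __init__(self, optimizer):
        self._optimizer = optimizer
        self.global_step = 0

    def step(self, closure=None):
        # in Lightning, the strategy- and precision-specific logic is injected here
        result = self._optimizer.step(closure)
        self.global_step += 1  # only the wrapped call advances the counter
        return result

    def __getattr__(self, name):
        # everything else (e.g. first_step, second_step) falls through to the
        # raw optimizer and never touches the wrapper
        return getattr(self._optimizer, name)
```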
To resolve this, we would need some mechanism to indicate what method we should wrap. cc @awaelchli @justusschock in case they have suggestions in this regard.
Another example is https://github.com/ludwigwinkler/JaxLightning/blob/8585863be636152b6adba77a0436ff7509fb92f3/BNN/JaxLightning_BNN.py#L215-L217 (cc @ludwigwinkler), which suffers from the same problem because the Jax optimizer uses .update() instead of .step().
The SAM optimizer training step can be rewritten in the classical form, with a single closure-based step call:
def training_step(self, batch, batch_idx):
    data, labels = batch
    opt = self.optimizers()

    def closure():
        loss = self.loss_fn(self.model(data), labels)
        loss.backward()
        return loss

    loss = self.loss_fn(self.model(data), labels)
    loss.backward()
    opt.step(closure)
    opt.zero_grad()
    print(self.trainer.global_step)
    return loss
After that, PL is able to wrap the .step() call, and self.trainer.global_step increases.
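As a side note, the same closure pattern can also route the backward passes through self.manual_backward() so the strategy and precision plugins stay in the loop. This is only a variation of the snippet above, sketched under that assumption, not a separately verified fix:

```python
def training_step(self, batch, batch_idx):
    data, labels = batch
    opt = self.optimizers()

    def closure():
        # second forward-backward pass, executed by SAM inside opt.step()
        loss = self.loss_fn(self.model(data), labels)
        self.manual_backward(loss)
        return loss

    # first forward-backward pass to populate the gradients used by first_step
    loss = self.loss_fn(self.model(data), labels)
    self.manual_backward(loss)
    opt.step(closure)
    opt.zero_grad()
    return loss
```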
If I understand this correctly, my proposal is to have a check in our LightningOptimizer wrapper that the step method is available. If not, raise an error suggesting the user do optimizer.step = optimizer.real_step_method in e.g. the configure_optimizers hook to have it supported in Lightning. IMO this is the easiest option and doesn't require new APIs.
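For illustration, the user-side aliasing from this proposal might look roughly like the sketch below; CustomOptimizer and real_step_method are placeholders taken from the proposal, not an existing API:

```python
def configure_optimizers(self):
    optimizer = CustomOptimizer(self.parameters(), lr=1e-3)  # placeholder optimizer
    # expose the method Lightning should wrap under the name it expects
    optimizer.step = optimizer.real_step_method
    return optimizer
```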
The suggestion to "have a check in our LightningOptimizer wrapper that the step method is available" is not foolproof: the SAM optimizer shown above offers first_step, second_step, and step. If the user didn't know about this limitation and called first_step and second_step, they would face this issue, but such a check wouldn't trigger because the optimizer also defines a step.
But I don't have a better suggestion that doesn't involve a complex solution, such as wrapping all optimizer methods and checking whether the parameters changed.
Bug description
I turned off automatic optimization because I am using the SAM optimizer (https://github.com/davda54/sam). After that, the trainer's global_step is not updated on each training step, so the checkpoint callback is never called even though it is passed to the trainer.
Used callback:
pl.callbacks.ModelCheckpoint(save_weights_only=True, save_top_k=3, monitor="val_acc", mode="max", save_on_train_epoch_end=False)
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
```
- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
- PyTorch Lightning Version: 1.8.4
- PyTorch Version: 1.13
- Python version: 3.9
```
More info
No response
cc @tchaton @justusschock @awaelchli @borda @carmocca