Open JanSellner opened 1 year ago
Digging a bit further into this: according to this line in the SWA implementation, https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/callbacks/stochastic_weight_avg.py#L254, the backward pass should be skipped in the last SWA epoch. The variable `_skip_backward`, defined in https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/loops/fit_loop.py#L133, is responsible for skipping the backward pass. However, this does not seem to work, because the backward pass of the optimization object is still called: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/loops/optimization/automatic.py#L185. `self._optimizer_step` is called even though `closure._backward_fn` is `None`.
Further, the variable `skipped_backward` in the `MixedPrecisionPlugin` class considers only the closure result but not `closure._backward_fn`. Maybe this is the error? I.e. changing `skipped_backward = closure_result is None` to `skipped_backward = closure._backward_fn is None` would solve the problem.
So I think this might indeed be an issue with lightning and not with PyTorch, but for some reason it only happens with the latest PyTorch version.
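To make the mismatch concrete, here is a minimal, self-contained sketch (not Lightning's actual source; the closure is reduced to the two pieces relevant here) of why the current check never detects the skipped backward pass:

```python
class Closure:
    def __init__(self, backward_fn):
        # Lightning sets closure._backward_fn to None when the fit loop's
        # _skip_backward flag is True (e.g. in the final SWA epoch).
        self._backward_fn = backward_fn

    def __call__(self):
        loss = 1.0  # stand-in for the training_step loss
        if self._backward_fn is not None:
            self._backward_fn(loss)  # would scale the loss and record inf checks
        return loss  # a value is returned either way

closure = Closure(backward_fn=None)  # backward skipped, as in the SWA epoch
closure_result = closure()

print(closure_result is None)        # False -> current check still calls scaler.step()
print(closure._backward_fn is None)  # True  -> proposed check would skip the step
```

Because the closure still returns the loss, `scaler.step(optimizer)` runs even though no inf checks were recorded, which matches the assertion message.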
As a workaround, we can switch to manual optimization in the SWA epoch:
```python
def on_train_epoch_start(self) -> None:
    if self.current_epoch == self.trainer.max_epochs - 1:
        # Workaround to always save the last epoch until the bug is fixed in lightning (https://github.com/Lightning-AI/lightning/issues/4539)
        self.trainer.check_val_every_n_epoch = 1

        # Disable backward pass for SWA until the bug is fixed in lightning (https://github.com/Lightning-AI/lightning/issues/17245)
        self.automatic_optimization = False
```
Hi,
Did you test your workaround solution? Sorry, I'm very new to lightning. I'm wondering where I should add the solution. Under https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/callbacks/stochastic_weight_avg.py#L177? Thanks!
I added the `on_train_epoch_start` method directly to my lightning module, e.g. as part of the `BoringModel` in the initial example from above. This is working so far. I guess it would also be possible to write a custom SWA callback which inherits from `StochasticWeightAveraging` and overrides the method there, but I have not tested it (a rough sketch of that variant follows below).
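A minimal, untested sketch of that callback variant (the class name is made up; it assumes the standard Lightning 2.0 callback hook signature):

```python
from lightning.pytorch.callbacks import StochasticWeightAveraging

class ManualOptSWA(StochasticWeightAveraging):
    """Untested sketch: applies the manual-optimization workaround from above
    inside the callback instead of patching each LightningModule."""

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        if trainer.current_epoch == trainer.max_epochs - 1:
            # Disable the backward pass in the SWA epoch (see workaround above).
            pl_module.automatic_optimization = False
```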
@JanSellner what PL version are you using?
2.0.1 at the time of creating this issue, but I also just reproduced it with 2.0.2.
I can confirm the undesired behaviour with pl 2.0.2.
Seems related to this issue regarding AMP on the torch forums. Maybe this helps?
I can confirm that this solution by @JanSellner stops the error being thrown, but not sure if this is the expected behaviour for SWA.
This has been tagged as `2.0.x`, but there seems to be no fix for it there. When is this targeted to be solved?
I am facing the same error if I try to use `accumulate_grad_batches` with a value larger than `1`. This argument works only when `automatic_optimization = True`.
So I downgraded `lightning` to `1.9.4` and the `pytorch` version to `1.13.1`. It didn't help. I also tried downgrading `lightning` first; it didn't help either.
Any ideas how to use `accumulate_grad_batches` when `automatic_optimization = False`? I know that we should call the optimizer and `manual_backward` in `training_step` when `automatic_optimization = False`; a sketch of that pattern follows below.
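For reference, a minimal sketch of manual gradient accumulation in `training_step` (the accumulation factor `4` and `compute_loss` are placeholders; this follows the usual pattern for manual optimization in Lightning):

```python
from lightning.pytorch import LightningModule

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.accumulate_batches = 4  # placeholder accumulation factor

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.compute_loss(batch)  # placeholder for your loss computation
        # Scale the loss so the accumulated gradient matches the full-batch gradient.
        self.manual_backward(loss / self.accumulate_batches)
        # Step and reset only every accumulate_batches batches.
        if (batch_idx + 1) % self.accumulate_batches == 0:
            opt.step()
            opt.zero_grad()
```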
UPDATE: setting `precision='bf16'` instead of `precision='16'` fixed the problem (latest `lightning` and `pytorch`).
Bug description
When SWA is used together with a model which has batch norm layers, the assertion `No inf checks were recorded for this optimizer.` is raised in the last epoch (= the SWA epoch). This worked fine with `torch<2.0`, but I am not sure whether it is a torch or lightning issue.
How to reproduce the bug
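The original repro snippet did not survive extraction. A minimal sketch of the setup described in this thread — a `BoringModel`-style module with a batch norm layer, the SWA callback, and 16-bit mixed precision on GPU — might look like this; the layer sizes and hyperparameters are assumptions:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        # The batch norm layer is what triggers SWA's extra last-epoch handling.
        self.layer = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.Linear(32, 2))

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

train_data = DataLoader(torch.randn(64, 32), batch_size=2)

trainer = Trainer(
    max_epochs=4,
    accelerator="gpu",  # AMP with precision=16 requires CUDA
    devices=1,
    precision=16,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
)
trainer.fit(BoringModel(), train_dataloaders=train_data)  # asserts in the last (SWA) epoch
```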
Error messages and logs
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.1
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0): 2.0
#- Python version (e.g., 3.9): 3.10
#- OS (e.g., Linux): Ubuntu
#- CUDA/cuDNN version: 10.8
#- GPU models and configuration: 3090
#- How you installed Lightning (`conda`, `pip`, source): pip
#- Running environment of LightningApp (e.g. local, cloud):
```

More info
No response