hankyul2 opened this issue 2 years ago.
Hi, as stated in https://github.com/PyTorchLightning/pytorch-lightning/issues/8100#issuecomment-867819299 this can be done by replacing just one part of our SWA callback.
@justusschock thank you for your reply.
Is there a way to use EMA via SWA?
I checked the link. As far as I can tell, SWA updates the LR scheduler and the model weights together. Am I right?
@hankyul2, I believe this is how it would be implemented:
import torch
from pytorch_lightning.callbacks import StochasticWeightAveraging


class EMA_Callback(StochasticWeightAveraging):
    def __init__(self, decay=0.9999):
        super().__init__()
        self.decay = decay

    def avg_fn(
        self, averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
    ) -> torch.FloatTensor:
        # Exponential moving average instead of SWA's equal-weight running mean
        e = averaged_model_parameter
        m = model_parameter
        return self.decay * e + (1. - self.decay) * m
Let me know if you have any problems. I just learned about SWA because of your issue!
@mathemusician thank you.
@hankyul2 I don't think @mathemusician's solution is equivalent: avg_fn is called once per epoch, while EMA updates happen every training step. I don't think EMA can be implemented with the SWA callback.
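A quick arithmetic sketch shows why per-epoch averaging is not a substitute. With the update ema = decay * ema + (1 - decay) * param, the starting weights keep a share of decay**n after n updates. The figure of 1000 steps per epoch below is a hypothetical value chosen only for illustration.

```python
def initial_weight_remaining(decay: float, num_updates: int) -> float:
    """Share of the starting weights left in the EMA after `num_updates` updates
    of the form ema = decay * ema + (1 - decay) * param."""
    return decay ** num_updates


decay = 0.9999
epochs = 100
steps_per_epoch = 1000  # hypothetical, for illustration only

# Updating once per epoch, as when the SWA callback calls avg_fn
per_epoch = initial_weight_remaining(decay, epochs)
# A true EMA: one update per training step
per_step = initial_weight_remaining(decay, epochs * steps_per_epoch)

print(f"updating once per epoch: {per_epoch:.4f} of the starting weights remain")
print(f"updating every step:     {per_step:.2e} of the starting weights remain")
```

With per-epoch updates and decay=0.9999, the average is still about 99% the starting weights after 100 epochs, so the decay would have to be retuned completely for epoch-level updates.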
@hal-314 yeah. I think so.
@hal-314 @mathemusician
I have implemented an EMA callback with simple functionality.
I think many more options should be added, for example save_weight, ema_step_period, etc.
If you find it helpful, please let me know and I will implement more. If not, close this issue or leave a comment.
from copy import deepcopy

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback


class EMACallback(Callback):
    def __init__(self, decay=0.995):
        self.decay = decay
        self.module_pair_list = []

    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        def forward_wrapper(module, org, ema):
            # Use the original forward in train mode and the EMA copy in eval mode
            def forward(*args, **kwargs):
                return org(*args, **kwargs) if module.training else ema(*args, **kwargs)
            return forward

        # Keep one EMA copy of every direct child module that has parameters
        modules = list(filter(lambda x: len(list(x[1].parameters())) > 0, pl_module.named_children()))
        for name, module in modules:
            ema_module = deepcopy(module)
            self.module_pair_list.append((ema_module, module))
            pl_module.add_module(f'EMA_{name}', ema_module)
            module.forward_bc = module.forward
            module.forward = forward_wrapper(module, module.forward_bc, ema_module.forward)

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        for ema_module, module in self.module_pair_list:
            self._update(ema_module, module, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def _update(self, ema_module, module, update_fn):
        # ema = decay * ema + (1 - decay) * model, applied to every state_dict entry
        with torch.no_grad():
            for ema_v, model_v in zip(ema_module.state_dict().values(), module.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))
@hankyul2 I prefer the approach from timm, as it doesn't involve replacing the forward method; with your approach, users could not use the forward method in their xxx_step methods. So I would recommend your first code, but updating the EMA weights in on_train_batch_end instead of on_before_backward to match timm's train.py.
Finally, I managed to implement EMA as a callback by using state_dicts instead of a copy of the whole module, as ModelEMAV2 does. I can try to make a PR or, at least, post the code here if you or anyone else is interested (and I get permission for it). It works on 1 GPU and likely on multiple GPUs, though that isn't tested.
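The core of a state-dict-based EMA is a per-key, in-place update run after each optimizer step (for example from on_train_batch_end). Below is a minimal sketch assuming plain torch.Tensor state dicts; `ema_update_` is a hypothetical helper name. As a design choice of this sketch, integer buffers such as BatchNorm's num_batches_tracked are copied rather than averaged.

```python
from typing import Dict

import torch
from torch import nn


@torch.no_grad()
def ema_update_(ema_state: Dict[str, torch.Tensor],
                model_state: Dict[str, torch.Tensor],
                decay: float) -> None:
    """Hypothetical helper: in-place EMA update of a detached state-dict copy."""
    for key, value in model_state.items():
        ema_value = ema_state[key]
        if ema_value.is_floating_point():
            # ema = decay * ema + (1 - decay) * value, computed in place
            ema_value.mul_(decay).add_(value.to(ema_value.device), alpha=1.0 - decay)
        else:
            # Integer buffers (e.g. num_batches_tracked) are copied as-is
            ema_value.copy_(value)


model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
# Clone: state_dict() tensors share storage with the live model
ema_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

# Pretend an optimizer step moved the weights, then update the EMA
with torch.no_grad():
    model[0].weight.add_(1.0)
ema_update_(ema_state, model.state_dict(), decay=0.9)
```

Because the EMA lives in a detached dict rather than a module, it can be moved to another device (or pinned CPU memory) without touching the model.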
@hal-314 Cool 😎. I want to see it. Can you share your code?
@hankyul2 Here is the code. Be aware that you need the overrides package installed (pip install overrides). If you don't want it, comment out the import and the @overrides decorators. I only use it to make sure that I'm actually overriding the methods correctly.
Things to be aware of:
- Override the get_state_dict method to filter out parameters that you don't want to include in the EMA, for example those belonging to metrics. See its docstring.

I hope it's useful.
from copy import deepcopy
from typing import Optional, Union, Dict, Any

import pytorch_lightning as pl
import torch
from overrides import overrides
from pytorch_lightning.utilities import rank_zero_only


class EMA(pl.Callback):
    """Implements EMA (exponential moving average) for any kind of model.
    EMA weights will be used during validation and stored separately from the original model weights.

    How to use EMA:
        - Sometimes, the last EMA checkpoint isn't the best, as metrics of EMA weights can show long oscillations
          in time. See https://github.com/rwightman/pytorch-image-models/issues/102
        - Batch Norm layers, and likely any other type of norm layer, don't need to be updated at the end. See
          discussions in: https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088 and
          https://github.com/rwightman/pytorch-image-models/issues/224
        - For object detection, SWA usually works better. See https://github.com/timgaripov/swa/issues/16

    Implementation detail:
        - See EMA in Pytorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914
        - With multiple GPUs, we broadcast EMA weights and the original weights in order to hold only 1 copy in memory.
          This is especially relevant when storing EMA weights on CPU + pinned memory, as pinned memory is a limited
          resource. In addition, we want to avoid duplicated operations in ranks != 0 to reduce jitter and improve
          performance.
    """

    def __init__(self, decay: float = 0.9999, ema_device: Optional[Union[torch.device, str]] = None, pin_memory=True):
        super().__init__()
        self.decay = decay
        self.ema_device: str = f"{ema_device}" if ema_device else None  # perform ema on different device from the model
        self.ema_pin_memory = pin_memory if torch.cuda.is_available() else False  # Only works if CUDA is available
        self.ema_state_dict: Dict[str, torch.Tensor] = {}
        self.original_state_dict = {}
        self._ema_state_dict_ready = False
    @staticmethod
    def get_state_dict(pl_module: pl.LightningModule):
        """Returns the state dictionary from pl_module. Override it if you want to filter some parameters and/or
        buffers out. For example, if pl_module has metrics, you don't want to return their parameters.

        code:
            # Only consider modules that can be seen by optimizers. Lightning modules can have other nn.Modules
            # attached, like losses, metrics, etc.
            patterns_to_ignore = ("metrics1", "metrics2")
            return dict(filter(lambda i: not i[0].startswith(patterns_to_ignore), pl_module.state_dict().items()))
        """
        return pl_module.state_dict()
    @overrides
    def on_train_start(self, trainer: "pl.Trainer", pl_module: pl.LightningModule) -> None:
        # Only keep track of EMA weights in rank zero.
        if not self._ema_state_dict_ready and pl_module.global_rank == 0:
            self.ema_state_dict = deepcopy(self.get_state_dict(pl_module))
            if self.ema_device:
                self.ema_state_dict = {k: tensor.to(device=self.ema_device) for k, tensor in self.ema_state_dict.items()}

            if self.ema_device == "cpu" and self.ema_pin_memory:
                self.ema_state_dict = {k: tensor.pin_memory() for k, tensor in self.ema_state_dict.items()}

        self._ema_state_dict_ready = True

    @rank_zero_only
    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: pl.LightningModule, *args, **kwargs) -> None:
        # Update EMA weights
        with torch.no_grad():
            for key, value in self.get_state_dict(pl_module).items():
                ema_value = self.ema_state_dict[key]
                ema_value.copy_(self.decay * ema_value + (1. - self.decay) * value, non_blocking=True)
    @overrides
    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no EMA weights have been loaded from a checkpoint.

        self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
        pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
        assert self.ema_state_dict.keys() == self.original_state_dict.keys(), \
            f"There are some keys missing in the broadcast EMA state dictionary. " \
            f"They are: {self.original_state_dict.keys() - self.ema_state_dict.keys()}"
        pl_module.load_state_dict(self.ema_state_dict, strict=False)

        if pl_module.global_rank > 0:
            # Remove the EMA state dict from memory. In rank 0, it could be in pinned RAM.
            self.ema_state_dict = {}

    @overrides
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no EMA weights have been loaded from a checkpoint.

        # Replace EMA weights with training weights
        pl_module.load_state_dict(self.original_state_dict, strict=False)
    @overrides
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> dict:
        return {"ema_state_dict": self.ema_state_dict, "_ema_state_dict_ready": self._ema_state_dict_ready}

    @overrides
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
    ) -> None:
        self._ema_state_dict_ready = callback_state["_ema_state_dict_ready"]
        self.ema_state_dict = callback_state["ema_state_dict"]
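When overriding get_state_dict to drop metric parameters, the filtering itself is a plain dict comprehension. A sketch follows; `filter_state_dict` and the "metrics1"/"metrics2" prefixes are hypothetical. Inside a subclass of the callback, get_state_dict could return `filter_state_dict(pl_module.state_dict(), ("metrics1", "metrics2"))`.

```python
from typing import Dict, Iterable, TypeVar

V = TypeVar("V")


def filter_state_dict(state_dict: Dict[str, V], prefixes_to_ignore: Iterable[str]) -> Dict[str, V]:
    """Drop every entry whose key starts with one of the given prefixes (hypothetical helper)."""
    prefixes = tuple(prefixes_to_ignore)  # str.startswith accepts a tuple of prefixes
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


state = {"backbone.weight": 1, "head.bias": 2, "metrics1.total": 3, "metrics2.count": 4}
print(filter_state_dict(state, ("metrics1", "metrics2")))
```

Note that prefix matching is literal: "metrics1" would also match a module named "metrics10", so including the trailing dot ("metrics1.") is safer when names overlap.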
@justusschock @mathemusician Do you think that lightning should include an EMA callback? Maybe, it could go to bolts.
@hal-314 wow... I think it is good (but I am not a maintainer or anything).
Personally, I think we should include this. However, I'm not sure where it belongs (currently I'd say not in Lightning core, but in either Flash or Bolts).
cc @tchaton @ethanwharris for opinions on this
@hal-314 thanks for your implementation! I agree that EMA would be very useful to have in Lightning.
I tested your implementation with default parameters (so ema_device=None, etc.) and it seems to work well on a single GPU. With multiple GPUs, the assertion in on_validation_start fails on GPUs other than 0. It appears that the ema_state_dict is not broadcast successfully to all devices (it is an empty dict).
@flukeskywalker Glad to see that it's useful for you :)
If you fix the multi-GPU code, would you mind sharing the fix so others can use it?
@hal-314 Thank you for sharing your code. I tested it with 2 GPUs in DDP mode.
When pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0) is executed, an OOM error occurs.
Can I ask how validation steps work in DDP mode, or is there any documentation I can refer to?
The whole error log is below.
Traceback (most recent call last):
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
return self._run_train()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
self.fit_loop.run()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 146, in run
self.on_advance_end()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 242, in on_advance_end
self._run_validation()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 337, in _run_validation
self.val_loop.run()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 140, in run
self.on_run_start(*args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 95, in on_run_start
self._on_evaluation_start()
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 179, in _on_evaluation_start
self.trainer.call_hook("on_validation_start", *args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1490, in call_hook
callback_fx(*args, **kwargs)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 216, in on_validation_start
callback.on_validation_start(self, self.lightning_module)
File "/home/hankyul/private/SuperConvergence/src/ema.py", line 134, in on_validation_start
pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 411, in broadcast
broadcast_object_list(obj, src, group=_group.WORLD)
File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1840, in broadcast_object_list
object_list[i] = _tensor_to_object(obj_view, obj_size)
File "/opt/conda/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 1532, in _tensor_to_object
return _unpickler(io.BytesIO(buf)).load()
File "/opt/conda/lib/python3.7/site-packages/torch/storage.py", line 161, in _load_from_bytes
return torch.load(io.BytesIO(b))
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 608, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 787, in _legacy_load
result = unpickler.load()
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 743, in persistent_load
deserialized_objects[root_key] = restore_location(obj, location)
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
result = fn(storage, location)
File "/opt/conda/lib/python3.7/site-packages/torch/serialization.py", line 155, in _cuda_deserialize
return storage_type(obj.size())
File "/opt/conda/lib/python3.7/site-packages/torch/cuda/__init__.py", line 606, in _lazy_new
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
@hankyul2 Sorry, but I don't have experience with multi-GPU in Lightning. From the stack trace, it seems that the OOM occurs when broadcasting the state. Here is where I found the broadcast docs in Lightning.
@hankyul2 I tried @hal-314's code on my machine with 2 1080 Ti GPUs, and it worked well. Maybe the out-of-memory error is literal, i.e. the GPU simply ran out of memory.
My experiment configuration is as follows:
Model: BertModelForSequenceClassification
max_length: 50
padding_to_max_length: True
batch_size: 4(per device)
Without EMA, GPU memory usage was 3311MB on each GPU. With EMA turned on, GPU 0 used 3851MB and GPU 1 used 3311MB.
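With ema_device=None, @hal-314's callback keeps the EMA copy only on rank 0, on the model's device, so the extra usage on GPU 0 (3851MB - 3311MB = 540MB here) should be roughly the size of one copy of the state dict, plus whatever the CUDA caching allocator reserves. A small hypothetical helper to estimate that size for your own model:

```python
import torch
from torch import nn


def state_dict_nbytes(module: nn.Module) -> int:
    """Bytes needed to hold one extra copy of `module`'s state dict (hypothetical helper)."""
    return sum(t.numel() * t.element_size() for t in module.state_dict().values())


layer = nn.Linear(10, 10)  # 100 weights + 10 biases, stored as float32
print(f"{state_dict_nbytes(layer)} bytes")
```

Applying it to the actual model and comparing against the measured difference is a quick way to tell whether an OOM comes from the EMA copy itself or from something else.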
@hal-314 @hankyul2 Sorry, I made a mistake in my last comment. I found that the broadcast in on_validation_start didn't work as expected, so I changed it as follows:
@overrides
def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    if not self._ema_state_dict_ready:
        return  # Skip Lightning sanity validation check if no ema weights has been loaded from a checkpoint.

    self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
    ema_state_dict = pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
    self.ema_state_dict = ema_state_dict
This is needed because broadcast does not modify its argument in place (in dist.py, when self.rank != 0 is True). The call stack I found is as follows:
pl_module.training_type_plugin.broadcast(self.ema_state_dict, 0)
-> ddp_spawn.broadcast(obj, src)
-> LightningDistributed.broadcast(obj, group)
and the source code is as follows:
# pytorch_lightning/plugins/training_type/ddp_spawn.py
def broadcast(self, obj: object, src: int = 0) -> object:
    if not distributed_available():
        return obj
    return self.dist.broadcast(obj)


# pytorch_lightning/distributed/dist.py
class LightningDistributed:
    def __init__(self, rank=None, device=None):
        self.rank = rank
        self.device = device

    def broadcast(self, obj: Any, group=_group.WORLD):
        # always wrap into a list so it can be broadcasted.
        obj = [obj]
        if self.rank != 0:
            obj = [None] * len(obj)

        broadcast_object_list(obj, 0, group=group or _group.WORLD)

        return obj[0]
GPU usage is shown below:
without ema:
| 0 N/A N/A 15733 C ...da3/envs/pl149/bin/python 4645MiB |
| 1 N/A N/A 15750 C ...da3/envs/pl149/bin/python 4645MiB |
with ema:
| 0 N/A N/A 12561 C ...da3/envs/pl149/bin/python 5509MiB |
| 0 N/A N/A 12577 C ...da3/envs/pl149/bin/python 1227MiB |
| 1 N/A N/A 12577 C ...da3/envs/pl149/bin/python 5063MiB |
environment: pytorch==1.8.0 pytorch-lightning==1.4.9 gpus: 1080ti * 2
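The bug comes down to Python rebinding: on ranks != 0, LightningDistributed.broadcast builds a new list, so the received object only exists in the return value. A pure-Python simulation of that control flow makes this visible; `fake_broadcast_object_list` is a hypothetical stand-in for torch.distributed.broadcast_object_list on a non-zero rank, and the payload is made up.

```python
from typing import Any, List

# What rank 0 "sends" in this simulation (made-up payload)
RANK0_EMA_STATE = {"weight": 0.5}


def fake_broadcast_object_list(object_list: List[Any]) -> None:
    """Hypothetical stand-in: fills the list slot with what rank 0 sent."""
    object_list[0] = dict(RANK0_EMA_STATE)


def broadcast(obj: Any, rank: int) -> Any:
    """Same control flow as LightningDistributed.broadcast quoted above."""
    wrapped = [obj]
    if rank != 0:
        # A brand-new list: the caller's `obj` is not referenced by it anymore
        wrapped = [None] * len(wrapped)
    fake_broadcast_object_list(wrapped)
    return wrapped[0]


ema_state_dict = {}                  # what a rank != 0 holds before validation
broadcast(ema_state_dict, rank=1)    # original code: return value discarded
print(ema_state_dict)                # still empty, so the key assertion fails

ema_state_dict = broadcast(ema_state_dict, rank=1)  # fixed: rebind to the result
print(ema_state_dict)
```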
I can confirm that @sevenights's fix works for me too with pytorch-lightning==1.5.8
I would love to see an implementation of EMA. I just migrated from ignite to lightning and lacking an EMA callback really holds me back.
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
@hankyul2 It is very likely that gradients were computed during validation, which caused the out-of-memory error. You can try torch.set_grad_enabled(False) in your validation process.
@AbyssGaze Thank you for your suggestion.
@Borda Hi, any plan on landing this feature?
Picking this back up as @lucidrains related issue requires EMA.
I think we should include this somewhere ASAP. Bolts would be the easiest landing place for such a callback. Any disagreements @Borda @justusschock? If so I can make an issue to get it into Bolts.
@SeanNaren Be aware that #5542 prevents EMA weights from being loaded automatically when only validating/testing (trainer.validate / trainer.test). In those situations, PL doesn't call callbacks.on_load_checkpoint.
To fix it, you will need to use a custom trainer in which you comment out this line: https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/trainer/trainer.py#L1068
Thanks for the heads up! So for fit it would be fine, and it's just an issue for validate/test? We should see what level of effort is required to fix this.
Validation and test, yes. Fine-tuning, I don't think so, although I didn't check.
yeah, so i'm trying to figure out whether this issue is a blocker to using lightning for a project
the project involves a model containing multiple subnetworks. during training, each subnetwork has an EMA version that is updated every so many training steps (say 10)
at validation time, i need to be able to call the EMAed versions of all the subnetworks sequentially. this does not have to be distributed
will that be doable given this open issue?
yes absolutely. Because the EMA weights are so closely tied to the actual model/subnetworks, I think it would be best to start by adding the logic directly into the pl.LightningModule. This open issue addresses a more general approach of keeping an EMAed version of the entire model, but in your case this generality isn't necessary.
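For the multi-subnetwork case, the logic can live in a small module owned by the LightningModule. A minimal sketch assuming plain nn.Module subnetworks; `MultiNetEMA` and `maybe_update` are hypothetical names. In a LightningModule you would build it in `__init__`, call `self.ema.maybe_update(self.global_step)` from `on_train_batch_end`, and at validation time call each `self.ema.ema_subnets[name]` in turn.

```python
import copy
from typing import Dict

import torch
from torch import nn


class MultiNetEMA(nn.Module):
    """Hypothetical helper: one EMA copy per subnetwork, updated every `every_n_steps` steps."""

    def __init__(self, subnets: Dict[str, nn.Module], decay: float = 0.999, every_n_steps: int = 10):
        super().__init__()
        self.subnets = nn.ModuleDict(subnets)
        # Frozen EMA copies, registered as submodules so they end up in checkpoints
        self.ema_subnets = nn.ModuleDict(
            {name: copy.deepcopy(net).requires_grad_(False) for name, net in subnets.items()}
        )
        self.decay = decay
        self.every_n_steps = every_n_steps

    @torch.no_grad()
    def maybe_update(self, global_step: int) -> bool:
        if global_step % self.every_n_steps != 0:
            return False
        for name, net in self.subnets.items():
            ema_net = self.ema_subnets[name]
            for p_ema, p in zip(ema_net.parameters(), net.parameters()):
                p_ema.lerp_(p, 1.0 - self.decay)  # p_ema += (1 - decay) * (p - p_ema)
            for b_ema, b in zip(ema_net.buffers(), net.buffers()):
                b_ema.copy_(b)  # buffers (e.g. BatchNorm stats) are copied, not averaged
        return True
```

Copying buffers instead of averaging them is a choice made for this sketch; averaging float buffers as well would be an equally valid variant.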
Any updates on this 👀 ?
Any updates? After the Lightning v1.8 update, the EMA callback implemented by @hal-314 in this issue is outdated.
Sorry for the late response here. Within NeMo, I have added an EMA callback which we have tested and used. It is based on the PyTorch Lightning Callback and can be seen here: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
We're making some performance improvements that may require more involvement from NeMo, so using it separately would require some stripping down. Progress can be seen here: https://github.com/NVIDIA/NeMo/pull/5169
Now that the EMA callback implemented in NeMo is complete, can this callback be integrated directly into PyTorch Lightning?
I too would be quite interested in this feature, based upon a model I am replicating that uses EMA. @justusschock
@Borda just curious where this is on the roadmap?
I think we can add it as experimental callback :)
@Borda I would be super happy to beta test this ASAP.
The EMA callback in NeMo shared above is licensed under Apache-2.0.
@lantiga, what are your thoughts? Maybe implement it in Bolts?
PyTorch added support for this in https://github.com/pytorch/pytorch/pull/94820.
@carmocca Perhaps a doc update then, given that lightning suggests SWA where EMA is sometimes superior?
I haven't tried the newly added EMA in PyTorch. Just wanted to share the info. If anybody gives it a shot, we would accept a docs contribution showing how to use it with Lightning.
After training with the NeMo callback, does loading the checkpoint automatically load the EMA parameters for inference, or do I need to swap the parameters in manually? @SeanNaren and others...
Just curious about the status here - Is there a recommended approach for this that the community has settled on?
I am curious too. What's the status of this important EMA callback?
Also would love this! SWA is great and feel like EMA should be even easier to add?
That NeMo EMA callback is great! Thank you.
Regarding whether the NeMo callback automatically loads the EMA parameters for inference when a checkpoint is loaded, see https://github.com/NVIDIA/NeMo/pull/5169#issuecomment-1485958187
I hope this will be helpful.
@flashszn @SeanNaren Thanks for sharing. When you save the top-k models, do you use the original model or the EMA model to evaluate the metrics?
@zhong-yy There's an evaluate_ema_weights_instead: bool = False argument in the EMA callback to choose whether the original model or the EMA model is used to evaluate the metrics.
🚀 Feature
How about adding EMA as a callback?
Motivation
I have had difficulty applying EMA. I think it would be nice to have EMA as a callback.
Pitch
If the user adds EMA as a callback, the EMA weights are used for validation and test.
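Applying EMA for validation and test boils down to swapping weights in before evaluation and restoring the training weights afterwards, which is what on_validation_start/on_validation_end do in @hal-314's callback above. A minimal sketch of that swap as a context manager; `ema_weights` is a hypothetical name and the EMA state is assumed to be a full state dict.

```python
from contextlib import contextmanager
from typing import Dict, Iterator

import torch
from torch import nn


@contextmanager
def ema_weights(model: nn.Module, ema_state: Dict[str, torch.Tensor]) -> Iterator[nn.Module]:
    """Temporarily load EMA weights into `model`, restoring the training weights on exit."""
    # Clone, because state_dict() tensors share storage with the live parameters
    original = {k: v.detach().clone() for k, v in model.state_dict().items()}
    model.load_state_dict(ema_state)
    try:
        yield model
    finally:
        model.load_state_dict(original)


model = nn.Linear(2, 1)
before = model.weight.detach().clone()
ema_state = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}  # toy EMA weights
x = torch.ones(1, 2)

with torch.no_grad(), ema_weights(model, ema_state):
    ema_out = model(x)  # evaluated with the (all-zero) EMA weights
print(ema_out.item(), torch.equal(model.weight.detach(), before))
```

The `finally` clause restores the training weights even if evaluation raises, which a pair of separate start/end hooks cannot guarantee on its own.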
Alternatives
Of course, you could add EMA as a tutorial instead, like the snippets below.
Additional context
cc @borda