Add feature Exponential Moving Average (EMA) #10914

Open hankyul2 opened 2 years ago

hankyul2 commented 2 years ago

🚀 Feature

How about adding EMA as a callback?


I have had difficulty applying EMA. I think it would be nice to have EMA available as a callback.


If the user adds EMA as a callback, the EMA weights are used for validation and test.
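For reference, this is the update the snippets in this thread apply. The notation is mine: decay $d$, live weights $\theta^{(t)}$ after step $t$, EMA weights $\theta_{\mathrm{EMA}}^{(t)}$, initialised by copying the model. Unrolling the recursion shows that older weights are down-weighted geometrically and that the coefficients sum to one. With $d = 0.9999$, the average therefore effectively spans on the order of $1/(1-d) = 10{,}000$ updates.

```latex
\theta_{\mathrm{EMA}}^{(0)} = \theta^{(0)}, \qquad
\theta_{\mathrm{EMA}}^{(t)} = d\,\theta_{\mathrm{EMA}}^{(t-1)} + (1-d)\,\theta^{(t)}

\theta_{\mathrm{EMA}}^{(t)} = d^{t}\,\theta^{(0)} + (1-d)\sum_{k=1}^{t} d^{\,t-k}\,\theta^{(k)},
\qquad d^{t} + (1-d)\sum_{k=1}^{t} d^{\,t-k} = 1
```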


Of course, you could also add EMA as a tutorial, like the snippet below:

from copy import deepcopy

import torch
from torch import nn
from torch.optim import SGD
from torchvision import models
from pytorch_lightning import LightningModule
from torchmetrics import MetricCollection, Accuracy


class EMA(nn.Module):
    """ Model Exponential Moving Average V2 from timm"""
    def __init__(self, model, decay=0.9999):
        super(EMA, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.decay = decay

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        # ema = decay * ema + (1 - decay) * model
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        # overwrite the EMA weights with the current model weights
        self._update(model, update_fn=lambda e, m: m)


class BasicModule(LightningModule):
    def __init__(self, lr=0.01, use_ema=False):
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        self.model_ema = EMA(self.model, decay=0.9) if use_ema else None
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

        metric = MetricCollection({'top@1': Accuracy(top_k=1), 'top@5': Accuracy(top_k=5)})
        self.train_metric = metric.clone(prefix='train_')
        self.valid_metric = metric.clone(prefix='valid_')

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        return self.shared_step(*batch, self.train_metric)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(*batch, self.valid_metric)

    def shared_step(self, x, y, metric):
        # train with the live weights; evaluate with the EMA weights when EMA is enabled
        y_hat = self.model(x) if self.training or self.model_ema is None else self.model_ema.module(x)
        loss = self.criterion(y_hat, y)
        self.log_dict(metric(y_hat, y), prog_bar=True)
        return loss

    def configure_optimizers(self):
        return SGD(self.model.parameters(), lr=self.lr)

    def on_before_backward(self, loss: torch.Tensor) -> None:
        if self.model_ema:
            self.model_ema.update(self.model)

Additional context

justusschock commented 2 years ago

Hi, as stated in the link, this can be done by replacing just one part of our SWA implementation.

hankyul2 commented 2 years ago

@justusschock thank you for your reply.

Is there a way to implement EMA using SWA?

I checked the link. To me, SWA seems to update the LR scheduler and the model weights together. Am I right?
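On the SWA question, the two techniques differ in the averaging rule itself, not only in the schedule. The plain-Python sketch below uses floats in place of tensors, and the function names are hypothetical. It contrasts the equal-weight running mean used for SWA-style averaging with an EMA, which gives recent values more weight.

```python
def swa_running_mean(values):
    """Equal-weight running mean: every averaged snapshot counts the same."""
    avg = values[0]
    for n, v in enumerate(values[1:], start=1):
        # n snapshots are already in the average; fold in the (n + 1)-th one.
        avg = avg + (v - avg) / (n + 1)
    return avg


def ema(values, decay):
    """Exponential moving average: recent values get geometrically more weight."""
    avg = values[0]
    for v in values[1:]:
        avg = decay * avg + (1.0 - decay) * v
    return avg


weights = [0.0, 1.0, 2.0, 3.0, 4.0]  # a toy "parameter" drifting over training
print(swa_running_mean(weights))     # the plain mean of all snapshots
print(ema(weights, decay=0.5))       # pulled towards the most recent values
```

With a drifting parameter, the running mean lands in the middle (2.0), while the EMA tracks the recent end (3.0625 with `decay=0.5`).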

mathemusician commented 2 years ago

@hankyul2, I believe this is how it would be implemented:

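import torch
from pytorch_lightning.callbacks import StochasticWeightAveraging

class EMA_Callback(StochasticWeightAveraging):
    def __init__(self, decay=0.9999, **swa_kwargs):
        # Pass the usual SWA arguments (e.g. swa_lrs) through to the parent callback.
        super().__init__(**swa_kwargs)
        self.decay = decay

    def avg_fn(
        self, averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
    ) -> torch.FloatTensor:
        # EMA instead of SWA's equal-weight running mean; num_averaged is unused.
        e = averaged_model_parameter
        m = model_parameter
        return self.decay * e + (1. - self.decay) * m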

Let me know if you have any problems. I just learned about SWA because of your issue!

hankyul2 commented 2 years ago

@mathemusician thank you.

hal-314 commented 2 years ago

@hankyul2 I don't think @mathemusician's solution is equivalent: avg_fn is called once per epoch, while EMA updates happen at every training step. I don't think EMA can be implemented with the SWA callback.
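Some numbers for the update-frequency point: after n updates, the weight an EMA still puts on the starting weights is `decay ** n`. Take `decay=0.9999` and a hypothetical 1000 steps per epoch over 100 epochs. Per-epoch updates leave about 99% of the average on the initial weights, while per-step updates leave only about 5e-5.

```python
def remaining_init_weight(decay: float, num_updates: int) -> float:
    """Weight the EMA still assigns to the initial weights after num_updates updates."""
    return decay ** num_updates


steps_per_epoch = 1000  # hypothetical dataset/batch size
epochs = 100
per_epoch = remaining_init_weight(0.9999, epochs)                    # SWA-style: once per epoch
per_step = remaining_init_weight(0.9999, epochs * steps_per_epoch)   # EMA-style: every step
print(f"per-epoch: {per_epoch:.3f}, per-step: {per_step:.1e}")
```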

hankyul2 commented 2 years ago

@hal-314 yeah. I think so.

hankyul2 commented 2 years ago

@hal-314 @mathemusician

I have implemented an EMA callback with basic functionality.

I think many more options should be added, for example save_weight, ema_step_period, etc.

If you find it helpful, please let me know and I will develop it further. If not, feel free to close this issue or leave a comment.

from copy import deepcopy

import torch
from pytorch_lightning import Callback

class EMACallback(Callback):
    def __init__(self, decay=0.995):
        self.decay = decay
        self.module_pair_list = []

    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        def forward_wrapper(module, org, ema):
            def forward(*args, **kwargs):
                # live weights while training, EMA weights during validation/test
                return org(*args, **kwargs) if module.training else ema(*args, **kwargs)
            return forward

        # only wrap children that own parameters (skips losses, activations, ...)
        modules = list(filter(lambda x: len(list(x[1].parameters())) > 0, pl_module.named_children()))

        for name, module in modules:
            ema_module = deepcopy(module)
            self.module_pair_list.append((ema_module, module))
            pl_module.add_module(f'EMA_{name}', ema_module)
            # keep the original forward and patch in the dispatching wrapper
            module.forward_bc = module.forward
            module.forward = forward_wrapper(module, module.forward_bc, ema_module.forward)

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        for ema_module, module in self.module_pair_list:
            self._update(ema_module, module, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def _update(self, ema_module, module, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(ema_module.state_dict().values(), module.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))

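Stripped of PyTorch, the forward-patching trick in the callback above comes down to a dispatch on the wrapped module's `training` flag. The condition on that line was garbled in the post, so reading it as `module.training` is an assumption. In the plain-Python sketch below, `TinyModule` is a hypothetical stand-in for `nn.Module` with one weight and a `training` attribute.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: one weight and a training flag."""

    def __init__(self, weight):
        self.weight = weight
        self.training = True

    def forward(self, x):
        return self.weight * x


def forward_wrapper(module, org, ema):
    def forward(*args, **kwargs):
        # Live weights while training, EMA copy otherwise.
        return org(*args, **kwargs) if module.training else ema(*args, **kwargs)
    return forward


live, ema_copy = TinyModule(2.0), TinyModule(1.5)
live.forward_bc = live.forward  # keep the original bound method
# The instance attribute shadows the class method, as in the callback.
live.forward = forward_wrapper(live, live.forward_bc, ema_copy.forward)
```

Because every call to `forward` is rerouted, any code that calls the module in eval mode silently gets the EMA weights. That is the side effect @hal-314 objects to in the next comment.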
hal-314 commented 2 years ago

@hankyul2 I prefer the approach from timm, as it doesn't involve patching the forward method; with the patched forward, users could not use the forward method in their xxx_step methods as usual. So I would recommend your first snippet, but updating the EMA weights in on_train_batch_end instead of on_before_backward to match the timm code.
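The hook choice matters because `on_before_backward` and `on_after_backward` both run before `optimizer.step()`. An EMA updated there always lags one step behind the live weights. `on_train_batch_end`, by contrast, sees the post-step weights, which matches timm's practice of updating after the optimizer step. Here is a toy loop in plain Python. The single float "weight" grows by 1.0 per step, and with `decay=0` the EMA simply copies whatever weights it was given, so the lag is easy to read off.

```python
def train(num_steps, decay, update_before_step):
    """Toy loop returning (final live weight, final EMA weight)."""
    weight, ema = 0.0, 0.0
    for _ in range(num_steps):
        if update_before_step:       # like on_before_backward / on_after_backward
            ema = decay * ema + (1 - decay) * weight
        weight += 1.0                # optimizer.step()
        if not update_before_step:   # like on_train_batch_end (timm style)
            ema = decay * ema + (1 - decay) * weight
    return weight, ema


print(train(3, decay=0.0, update_before_step=True))   # EMA one step behind
print(train(3, decay=0.0, update_before_step=False))  # EMA matches the final weights
```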

Finally, I managed to implement EMA as a callback by using state_dicts instead of copying the whole module as ModelEmaV2 does. I can try to make a PR or, at least, post the code here if you or anyone else is interested (and I get permission for it). It works on 1 GPU and likely on multiple GPUs, though that isn't tested.
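Here is a minimal sketch of the state-dict idea in plain Python. Dicts of floats stand in for state dicts of tensors, and `filtered_state` and the `metrics` prefix are hypothetical. It also shows the kind of key filtering you might want so that non-trainable attachments such as metrics stay out of the average.

```python
def filtered_state(state, ignore_prefixes=("metrics",)):
    """Keep only entries whose keys don't start with an ignored prefix."""
    return {k: v for k, v in state.items() if not k.startswith(ignore_prefixes)}


def ema_state_update(ema_state, live_state, decay):
    """In-place EMA update of a flat {name: value} state, one entry at a time."""
    for key, value in filtered_state(live_state).items():
        ema_state[key] = decay * ema_state[key] + (1.0 - decay) * value
    return ema_state


live = {"model.w": 1.0, "model.b": -1.0, "metrics.count": 5.0}
ema_state = filtered_state({"model.w": 0.0, "model.b": 0.0, "metrics.count": 0.0})
ema_state_update(ema_state, live, decay=0.9)
```

Keeping a flat dict rather than a second module means nothing extra is registered on the LightningModule, and the tensors can live on a different device.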

hankyul2 commented 2 years ago

@hal-314 Cool 😎. I want to see it. Can you share your code?

hal-314 commented 2 years ago

@hankyul2 Here is the code. Be aware that you need the overrides package installed (pip install overrides). If you don't want it, comment out the import and the @overrides decorator. I only use it to make sure that I'm actually overriding the methods correctly.

Things to be aware of:

I hope it's useful.

from copy import deepcopy
from typing import Optional, Union, Dict, Any

import pytorch_lightning as pl
import torch
from overrides import overrides
from pytorch_lightning.utilities import rank_zero_only

class EMA(pl.Callback):
    """Implements EMA (exponential moving average) to any kind of model.
    EMA weights will be used during validation and stored separately from original model weights.

    How to use EMA:
        - Sometimes, last EMA checkpoint isn't the best as EMA weights metrics can show long oscillations in time. See

        - Batch Norm layers and likely any other type of norm layers doesn't need to be updated at the end. See
          discussions in: and

        - For object detection, SWA usually works better. See

    Implementation detail:
        - See EMA in Pytorch Lightning:
        - When multi gpu, we broadcast ema weights and the original weights in order to only hold 1 copy in memory.
          This is specially relevant when storing EMA weights on CPU + pinned memory as pinned memory is a limited
          resource. In addition, we want to avoid duplicated operations in ranks != 0 to reduce jitter and improve
    def __init__(self, decay: float = 0.9999, ema_device: Optional[Union[torch.device, str]] = None, pin_memory=True):
        self.decay = decay
        self.ema_device: str = f"{ema_device}" if ema_device else None  # perform ema on different device from the model
        self.ema_pin_memory = pin_memory if torch.cuda.is_available() else False  # Only works if CUDA is available
        self.ema_state_dict: Dict[str, torch.Tensor] = {}
        self.original_state_dict = {}
        self._ema_state_dict_ready = False

    @staticmethod
    def get_state_dict(pl_module: pl.LightningModule):
        """Returns the state dictionary from pl_module. Override it if you want to filter some parameters and/or
        buffers out. For example, if pl_module has metrics, you don't want to return their parameters:

            # Only consider modules that can be seen by optimizers. Lightning modules can have other nn.Modules attached
            # like losses, metrics, etc.
            patterns_to_ignore = ("metrics1", "metrics2")
            return dict(filter(lambda i: not i[0].startswith(patterns_to_ignore), pl_module.state_dict().items()))
        """
        return pl_module.state_dict()

    def on_train_start(self, trainer: "pl.Trainer", pl_module: pl.LightningModule) -> None:
        # Only keep track of EMA weights in rank zero.
        if not self._ema_state_dict_ready and pl_module.global_rank == 0:
            self.ema_state_dict = deepcopy(self.get_state_dict(pl_module))
            if self.ema_device:
                self.ema_state_dict = {k: tensor.to(device=self.ema_device) for k, tensor in self.ema_state_dict.items()}

            if self.ema_device == "cpu" and self.ema_pin_memory:
                self.ema_state_dict = {k: tensor.pin_memory() for k, tensor in self.ema_state_dict.items()}

        self._ema_state_dict_ready = True

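    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: pl.LightningModule, *args, **kwargs) -> None:
        # Only rank zero keeps (and therefore updates) the EMA weights.
        if pl_module.global_rank == 0:
            # Update EMA weights
            with torch.no_grad():
                for key, value in self.get_state_dict(pl_module).items():
                    ema_value = self.ema_state_dict[key]
                    # Move the live value to the EMA device (a no-op when ema_device is None).
                    value = value.to(ema_value.device, non_blocking=True)
                    ema_value.copy_(self.decay * ema_value + (1. - self.decay) * value, non_blocking=True)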

    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no ema weights have been loaded from a checkpoint.

        self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
        pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
        assert self.ema_state_dict.keys() == self.original_state_dict.keys(), \
            f"There are some keys missing in the broadcast ema state dictionary. " \
            f"They are: {self.original_state_dict.keys() - self.ema_state_dict.keys()}"
        pl_module.load_state_dict(self.ema_state_dict, strict=False)

        if pl_module.global_rank > 0:
            # Remove ema state dict from the memory. In rank 0, it could be in ram pinned memory.
            self.ema_state_dict = {}

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self._ema_state_dict_ready:
            return  # Skip Lightning sanity validation check if no ema weights have been loaded from a checkpoint.

        # Replace EMA weights with training weights
        pl_module.load_state_dict(self.original_state_dict, strict=False)

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> dict:
        return {"ema_state_dict": self.ema_state_dict, "_ema_state_dict_ready": self._ema_state_dict_ready}

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
    ) -> None:
        self._ema_state_dict_ready = callback_state["_ema_state_dict_ready"]
        self.ema_state_dict = callback_state["ema_state_dict"]
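The validation hooks above follow a stash, load, restore pattern. The plain-Python sketch below mirrors that order, with dicts standing in for state dicts and `dict.update` playing the role of `load_state_dict(..., strict=False)`. `EmaSwap` is a hypothetical helper, not part of the callback.

```python
import copy


class EmaSwap:
    """Hypothetical helper: swap EMA weights in for validation, then restore."""

    def __init__(self):
        self.original = {}

    def on_validation_start(self, live_state, ema_state):
        # Stash a real copy of the training weights before overwriting them.
        self.original = copy.deepcopy(live_state)
        # Like strict=False: keys missing from ema_state keep their live values.
        live_state.update(ema_state)

    def on_validation_end(self, live_state):
        live_state.update(self.original)


live = {"w": 1.0, "metric_total": 3.0}
swap = EmaSwap()
swap.on_validation_start(live, {"w": 0.5})
during = dict(live)   # what validation would see
swap.on_validation_end(live)
```

The `deepcopy` matters in the real callback because, in PyTorch, `state_dict()` returns tensors that share storage with the live parameters. Without a copy, loading the EMA weights would also overwrite the stash.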

@justusschock @mathemusician Do you think Lightning should include an EMA callback? Maybe it could go into Bolts.

hankyul2 commented 2 years ago

@hal-314 Wow... I think it is good (but I am not a maintainer or anything).

justusschock commented 2 years ago

Personally, I think we should include this. However, I'm not sure where it belongs (currently I'd say not in Lightning core, but in either Flash or Bolts).

cc @tchaton @ethanwharris for opinions on this

flukeskywalker commented 2 years ago

@hal-314 thanks for your implementation! I agree that EMA would be very useful to have in Lightning.

I tested your implementation with default parameters (so ema_device=None, etc.) and it seems to work well on a single GPU. On multiple GPUs, the assertion in on_validation_start fails on GPUs other than 0. It appears that ema_state_dict is not broadcast successfully to all devices (it is an empty dict).

hal-314 commented 2 years ago

@flukeskywalker Glad to see that it's useful for you :)

If you fix the multi-GPU code, would you mind sharing the fix so others can use it?

hankyul2 commented 2 years ago

@hal-314 Thank you for sharing your code. I tested it with 2 GPUs in DDP mode. When pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0) is executed, an OOM error occurs. Could you explain how validation steps work in DDP mode, or point me to any document I can reference?

The whole error log is below.

Traceback (most recent call last):
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/", line 1195, in _run
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/", line 1274, in _dispatch
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/", line 1284, in run_stage
    return self._run_train()
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/", line 1314, in _run_train
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/", line 234, in advance
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/", line 146, in run
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/", line 242, in on_advance_end
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/", line 337, in _run_validation
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/", line 140, in run
    self.on_run_start(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/", line 95, in on_run_start
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/", line 179, in _on_evaluation_start
    self.trainer.call_hook("on_validation_start", *args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/", line 1490, in call_hook
    callback_fx(*args, **kwargs)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/trainer/", line 216, in on_validation_start
    callback.on_validation_start(self, self.lightning_module)
  File "/home/hankyul/private/SuperConvergence/src/", line 134, in on_validation_start
    pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
  File "/home/hankyul/.local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/", line 411, in broadcast
    broadcast_object_list(obj, src, group=_group.WORLD)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/", line 1840, in broadcast_object_list
    object_list[i] = _tensor_to_object(obj_view, obj_size)
  File "/opt/conda/lib/python3.7/site-packages/torch/distributed/", line 1532, in _tensor_to_object
    return _unpickler(io.BytesIO(buf)).load()
  File "/opt/conda/lib/python3.7/site-packages/torch/", line 161, in _load_from_bytes
    return torch.load(io.BytesIO(b))
  File "/opt/conda/lib/python3.7/site-packages/torch/", line 608, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/opt/conda/lib/python3.7/site-packages/torch/", line 787, in _legacy_load
    result = unpickler.load()
  File "/opt/conda/lib/python3.7/site-packages/torch/", line 743, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "/opt/conda/lib/python3.7/site-packages/torch/", line 175, in default_restore_location
    result = fn(storage, location)
  File "/opt/conda/lib/python3.7/site-packages/torch/", line 155, in _cuda_deserialize
    return storage_type(obj.size())
  File "/opt/conda/lib/python3.7/site-packages/torch/cuda/", line 606, in _lazy_new
    return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
hal-314 commented 2 years ago

@hankyul2 Sorry, but I don't have experience with multi-GPU in Lightning. From the stack trace, it seems that the OOM occurs when broadcasting the state. Here is where I found the broadcast docs in Lightning.

sevenights commented 2 years ago

@hankyul2 I tried @hal-314's code on my machine with two 1080 Ti GPUs, and it worked well. Maybe the out-of-memory error really is just running out of memory...

My experiment configuration is as follows:

Model: BertModelForSequenceClassification
max_length: 50
padding_to_max_length: True
batch_size: 4(per device)

When I trained without EMA, GPU memory usage was 3311MB on each GPU. With EMA turned on, GPU 0 used 3851MB and GPU 1 used 3311MB.
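Here is a quick sanity check on those numbers. It assumes BERT-base has roughly 110M parameters stored in float32 (4 bytes each). One extra copy of the weights then comes to about 420 MiB. Only rank 0 keeps the EMA copy, which is consistent with only GPU 0 growing (by 540MB here). This is an estimate, not a measurement, and it doesn't account for allocator overhead or other buffers.

```python
def fp32_copy_mib(num_params: int) -> float:
    """Memory for one extra float32 copy of num_params weights, in MiB."""
    return num_params * 4 / 2**20


# ~110M parameters for BERT-base is an approximate, assumed figure.
print(round(fp32_copy_mib(110_000_000)))
```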

sevenights commented 2 years ago

@hal-314 @hankyul2 Sorry, I made a mistake in my last comment. I found that the broadcast in on_validation_start didn't work as expected, so I changed it as follows:

def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    if not self._ema_state_dict_ready:
        return  # Skip Lightning sanity validation check if no ema weights has been loaded from a checkpoint.

    self.original_state_dict = deepcopy(self.get_state_dict(pl_module))
    ema_state_dict = pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0)
    self.ema_state_dict = ema_state_dict

This is needed because broadcast does not change the variable in place (when self.rank != 0 is True, it fills a new list and returns its first element). The call stack I found is as follows:

pl_module.training_type_plugin.broadcast(self.ema_state_dict, 0)
-> ddp_spawn.broadcast(obj, src)
-> LightningDistributed.broadcast(obj, group)

and the source code is as follows:

# pytorch_lightning/plugins/training_type/
def broadcast(self, obj: object, src: int = 0) -> object:
    if not distributed_available():
        return obj
    return self.dist.broadcast(obj)
# pytorch_lightning/distributed/
class LightningDistributed:
    def __init__(self, rank=None, device=None):
        self.rank = rank
        self.device = device

    def broadcast(self, obj: Any, group=_group.WORLD):
        # always wrap into a list so it can be broadcasted.
        obj = [obj]

        if self.rank != 0:
            obj = [None] * len(obj)

        broadcast_object_list(obj, 0, group=group or _group.WORLD)

        return obj[0]
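So on non-source ranks, the received object exists only in the return value. The single-process imitation below shows why discarding that return value leaves rank 1's `ema_state_dict` empty, which matches the empty dict @flukeskywalker saw. `FakeWorld` and this `broadcast` function are hypothetical, and no `torch.distributed` is involved.

```python
import copy


class FakeWorld:
    """Hypothetical single-process stand-in for a process group."""

    def __init__(self):
        self.wire = None


def broadcast(world, rank, obj, src=0):
    # Mirrors the structure of LightningDistributed.broadcast above.
    obj = [obj]
    if rank != src:
        obj = [None]                         # caller's object is no longer referenced
        obj[0] = copy.deepcopy(world.wire)   # "receive" into the fresh list
    else:
        world.wire = copy.deepcopy(obj[0])   # "send" from the source rank
    return obj[0]


world = FakeWorld()
rank0_state = {"w": 0.5}
rank1_state = {}                             # never filled on rank 1

broadcast(world, 0, rank0_state)
broadcast(world, 1, rank1_state)             # return value ignored: nothing changes
received = broadcast(world, 1, rank1_state)  # return value kept: the fix
```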

The GPU usage is shown below:

without ema:

|    0   N/A  N/A     15733      C   ...da3/envs/pl149/bin/python     4645MiB |
|    1   N/A  N/A     15750      C   ...da3/envs/pl149/bin/python     4645MiB |

with ema:

|    0   N/A  N/A     12561      C   ...da3/envs/pl149/bin/python     5509MiB |
|    0   N/A  N/A     12577      C   ...da3/envs/pl149/bin/python     1227MiB |
|    1   N/A  N/A     12577      C   ...da3/envs/pl149/bin/python     5063MiB |

environment: pytorch==1.8.0 pytorch-lightning==1.4.9 gpus: 1080ti * 2

flukeskywalker commented 2 years ago

I can confirm that @sevenights's fix works for me too with pytorch-lightning==1.5.8

yoyololicon commented 2 years ago

I would love to see an implementation of EMA. I just migrated from ignite to lightning and lacking an EMA callback really holds me back.

AbyssGaze commented 2 years ago

@hankyul2 It is very likely that gradients were computed during validation, which caused the out-of-memory error. You can try torch.set_grad_enabled(False) in your validation process.

hankyul2 commented 2 years ago

@AbyssGaze Thank you for your suggestion.

Ir1d commented 2 years ago

@Borda Hi, any plan on landing this feature?

SeanNaren commented 2 years ago

Picking this back up, as @lucidrains' related issue requires EMA.

I think we should include this somewhere ASAP. Bolts would be the easiest landing place for such a callback. Any disagreements @Borda @justusschock? If so I can make an issue to get it into Bolts.

hal-314 commented 2 years ago

@SeanNaren Be aware that #5542 prevents EMA weights from being loaded automatically when only validating/testing (trainer.validate / trainer.test).

In those situations, PL doesn't call callbacks.on_load_checkpoint.

To fix it, you will need to use a custom trainer so that you can comment out this line

SeanNaren commented 2 years ago

Thanks for the heads-up! So for fit it would be fine, and it's just an issue for validate/test? We should see what level of effort is required to fix this.

hal-314 commented 2 years ago

Validation and test, yes. Fine-tuning, I don't think so, although I didn't check.

lucidrains commented 2 years ago

Yeah, so I'm trying to figure out whether this issue is a blocker to using Lightning for a project.

The project involves a model containing multiple subnetworks. During training, each subnetwork has an EMA version that is updated every so many training steps (say, 10).

At validation time, I need to be able to call the EMA versions of all the subnetworks sequentially. This does not have to be distributed.

Will that be doable given this open issue?

SeanNaren commented 2 years ago

Yes, absolutely. Because the EMA weights are so closely tied to the actual model/subnetworks, I think it would be best to start by adding the logic directly into the pl.LightningModule. This open issue is about a more general approach of keeping an EMA version of the entire model, but in your case that generality isn't necessary.
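Here is a sketch of that module-level approach for the case above, in plain Python with floats standing in for tensors. It keeps one EMA shadow per named subnetwork and refreshes them every N steps. `PeriodicEMA` and the subnetwork names are hypothetical. In a LightningModule you might call `maybe_update(self.global_step, ...)` from `on_train_batch_end` and loop over `shadow` at validation time, but that wiring is an assumption, not an API Lightning provides.

```python
class PeriodicEMA:
    """Hypothetical helper: one EMA shadow per subnetwork, refreshed every `every` steps."""

    def __init__(self, subnets, decay=0.999, every=10):
        self.decay = decay
        self.every = every
        # Copy each subnetwork's starting weights.
        self.shadow = {name: dict(params) for name, params in subnets.items()}

    def maybe_update(self, step, subnets):
        """Blend live weights into the shadows on every `every`-th step."""
        if step % self.every != 0:
            return False
        for name, params in subnets.items():
            ema = self.shadow[name]
            for key, value in params.items():
                ema[key] = self.decay * ema[key] + (1.0 - self.decay) * value
        return True


subnets = {"generator": {"w": 0.0}, "critic": {"w": 0.0}}
ema = PeriodicEMA(subnets, decay=0.5, every=10)
subnets["generator"]["w"] = 1.0              # pretend training moved one subnetwork
skipped = ema.maybe_update(5, subnets)       # not a multiple of 10: no update
updated = ema.maybe_update(10, subnets)
```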

DanBigioi commented 2 years ago

Any updates on this 👀 ?

BakerBunker commented 1 year ago

Any updates? After the Lightning v1.8 update, the EMA callback implemented by @hal-314 in this issue no longer works, because APIs it relies on were deprecated.

SeanNaren commented 1 year ago

Sorry for the late response here. Within NeMo I have added an EMA callback, which we have tested and used. It is based on the PyTorch Lightning Callback and can be seen here:

We're making some performance improvements that may require more involvement from NeMo, so using it separately would require some stripping down. Progress can be seen here:

BakerBunker commented 1 year ago

Now that the EMA callback implemented in NeMo has been completed, can this callback be integrated directly into PyTorch Lightning?

turian commented 1 year ago

I too would be quite interested in this feature, since a model I am replicating uses EMA. @justusschock

turian commented 1 year ago

@Borda Just curious where this is on the roadmap?

Borda commented 1 year ago

I think we can add it as an experimental callback :)

turian commented 1 year ago

@Borda I would be super happy to beta test this ASAP.

The EMA callback in NeMo shared above is licensed under Apache-2.0.

Borda commented 1 year ago

@lantiga, what are your thoughts? Maybe implement it in Bolts?

carmocca commented 1 year ago

PyTorch added support for this with (commit)
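I haven't checked the linked commit. As far as I know, `torch.optim.swa_utils.AveragedModel` takes an `avg_fn(averaged_param, current_param, num_averaged)` callable and simply copies the weights on its first `update_parameters()` call, so EMA only needs a different `avg_fn`. Check the docs for your PyTorch version, since newer releases also ship EMA helpers there. The sketch below is a plain-Python imitation of that copy-then-average behaviour, not the torch API itself, and `AveragedParams` is a hypothetical name.

```python
class AveragedParams:
    """Plain-Python imitation of copy-then-average behaviour (not the torch API)."""

    def __init__(self, avg_fn):
        self.avg_fn = avg_fn
        self.value = None
        self.n_averaged = 0

    def update_parameters(self, model_value):
        if self.n_averaged == 0:
            self.value = model_value  # first call: plain copy of the weights
        else:
            self.value = self.avg_fn(self.value, model_value, self.n_averaged)
        self.n_averaged += 1


def ema_avg_fn(decay):
    # Same three-argument shape as an SWA averaging function; num_averaged is unused.
    return lambda averaged, current, num_averaged: decay * averaged + (1 - decay) * current


avg = AveragedParams(ema_avg_fn(0.5))
avg.update_parameters(2.0)
avg.update_parameters(4.0)
```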

turian commented 1 year ago

@carmocca Perhaps a doc update then, given that lightning suggests SWA where EMA is sometimes superior?

carmocca commented 1 year ago

I haven't tried the newly added EMA in PyTorch. Just wanted to share the info. If anybody gives it a shot, we would accept a docs contribution showing how to use it with Lightning.

shivammehta25 commented 1 year ago

After using the NeMo callback, does loading the checkpoint automatically load the EMA parameters for inference? Or do I need to switch the parameters manually? @SeanNaren and others...

jlotthammer commented 7 months ago

Just curious about the status here - Is there a recommended approach for this that the community has settled on?

Hans-digit commented 7 months ago

I am curious too. What's going on with this important EMA callback?

lukasschmit commented 7 months ago

I would also love this! SWA is great, and I feel like EMA should be even easier to add.

KasuganoLove commented 6 months ago

Sorry for the late response here, within NeMo I have added an EMA callback which we have tested/used. This is based on the PyTorch Lightning Callback, can be seen here:

We're doing some performance improvements that may require more involvement from NeMo, so to use separately would require some stripping down. Progress can be seen here: NVIDIA/NeMo#5169

That's great! Thank you.

KasuganoLove commented 6 months ago

After using the Nemo callback, on loading the checkpoint, does it automatically load the EMA parameters for inference? Or do I need to switch the parameters somehow manually? @SeanNaren and others...

I hope this will be helpful.

zhong-yy commented 4 months ago

@flashszn @SeanNaren Thanks for sharing. When you save the top-k models, do you use the original model or the EMA model to evaluate the metrics?

KasuganoLove commented 4 months ago

@zhong-yy There's an evaluate_ema_weights_instead: bool = False option in the EMA callback to choose whether the original model or the EMA model is used to evaluate the metrics.