Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Add feature Exponential Moving Average (EMA) #10914

Open hankyul2 opened 2 years ago

hankyul2 commented 2 years ago

🚀 Feature

How about adding EMA as a callback?

Motivation

I have had difficulty applying EMA myself. I think it would be nice if Lightning provided EMA as a built-in callback.

Pitch

If the user adds the EMA callback, the averaged weights are used for validation and test. A rough sketch of what such a callback could look like is shown below.
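
A minimal sketch, assuming the standard Callback hooks on_train_batch_end, on_validation_start, and on_validation_end; the class EMACallback and its internals are hypothetical, not an existing Lightning API:

from copy import deepcopy

import torch
from pytorch_lightning import Callback

class EMACallback(Callback):
    """Hypothetical callback: keeps an EMA copy of the weights and swaps it in for evaluation."""

    def __init__(self, decay=0.9999):
        self.decay = decay
        self.ema_state = None   # shadow (averaged) weights
        self._backup = None     # live weights stashed while validating

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # refresh the shadow weights after every training step
        with torch.no_grad():
            if self.ema_state is None:
                self.ema_state = deepcopy(pl_module.state_dict())
            else:
                for k, v in pl_module.state_dict().items():
                    # ema = decay * ema + (1 - decay) * model (copy_ casts back for integer buffers)
                    self.ema_state[k].copy_(self.ema_state[k] * self.decay + v * (1.0 - self.decay))

    def on_validation_start(self, trainer, pl_module):
        # swap the EMA weights in so validation runs on the averaged model
        if self.ema_state is not None:
            self._backup = deepcopy(pl_module.state_dict())
            pl_module.load_state_dict(self.ema_state)

    def on_validation_end(self, trainer, pl_module):
        # restore the live training weights
        if self._backup is not None:
            pl_module.load_state_dict(self._backup)
            self._backup = None

The same swap-in/swap-out pair could be repeated for on_test_start / on_test_end so that testing also uses the averaged weights.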

Alternatives

Of course, this could instead be documented as a tutorial, e.g. with a snippet like the one below.

from copy import deepcopy

import torch
from torch import nn
from torch.optim import SGD
from torchvision import models
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy, MetricCollection

class EMA(nn.Module):
    """Model Exponential Moving Average V2 from timm"""
    def __init__(self, model, decay=0.9999):
        super().__init__()
        # make a copy of the model for accumulating the moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        # ema = decay * ema + (1 - decay) * model
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        # copy the model weights into the EMA copy verbatim
        self._update(model, update_fn=lambda e, m: m)

class BasicModule(LightningModule):
    def __init__(self, lr=0.01, use_ema=False):
        super().__init__()
        self.model = models.resnet18(pretrained=False)
        self.model_ema = EMA(self.model, decay=0.9) if use_ema else None
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

        metric = MetricCollection({'top@1': Accuracy(top_k=1), 'top@5': Accuracy(top_k=5)})
        self.train_metric = metric.clone(prefix='train_')
        self.valid_metric = metric.clone(prefix='valid_')

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        return self.shared_step(*batch, self.train_metric)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(*batch, self.valid_metric)

    def shared_step(self, x, y, metric):
        # train on the live model, evaluate on the EMA copy when it exists
        y_hat = self.model(x) if self.training or self.model_ema is None else self.model_ema.module(x)
        loss = self.criterion(y_hat, y)
        self.log_dict(metric(y_hat, y), prog_bar=True)
        return loss

    def configure_optimizers(self):
        return SGD(self.model.parameters(), lr=self.lr)

    def on_before_backward(self, loss: torch.Tensor) -> None:
        # refresh the EMA weights once per training step
        if self.model_ema:
            self.model_ema.update(self.model)
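
For completeness, a minimal way to run the snippet above (the dataloader names and Trainer arguments are placeholders, not part of the original post):

from pytorch_lightning import Trainer

model = BasicModule(lr=0.01, use_ema=True)
trainer = Trainer(max_epochs=10)
# train_loader / val_loader are assumed to be ordinary torch DataLoaders
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)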

Additional context



cc @borda

je-santos commented 2 months ago

Any update on this?

lyndonlauder commented 1 week ago

This should have been added long ago. @williamFalcon please try to get your team to move this along; it has been a stumbling block with Lightning for many users for years now.