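A couple of separate issues:

First, instantiating LitEma as LitEma(self, ...) stores the wrong parameter names in its dictionary. LitEma records the names from whatever module it is given. Passing the wrapper (self) instead of self.model therefore prefixes every name with "model.". When the EMA is later updated with self.model, the unprefixed names are not found and the update fails with an error.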
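The name mismatch can be reproduced with plain torch.nn. This is a minimal sketch, not LitEma itself. Wrapper is a hypothetical stand-in for the LightningModule that owns self.model. The "." stripping that builds name_map is assumed to mirror how LitEma turns parameter names into buffer names.

```python
import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for the LightningModule that owns self.model."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(2, 2)


w = Wrapper()
names_from_self = [name for name, _ in w.named_parameters()]
names_from_model = [name for name, _ in w.model.named_parameters()]

# Assumed to mirror LitEma's name map: "." is stripped because it is not allowed in buffer names
name_map = {name: name.replace(".", "") for name in names_from_self}

print(names_from_self)   # ['model.weight', 'model.bias']
print(names_from_model)  # ['weight', 'bias']
# An update with w.model looks up "weight", which was never registered
print("weight" in name_map)  # False
```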
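Second, because LitEma defaults to "use_num_updates=True", the declared decay factor is effectively ignored unless training runs for a very large number of updates. In the "forward" method, LitEma replaces the decay with min(decay, (1 + num_updates)/(10 + num_updates)), and num_updates is incremented on every call. This warmup term stays well below typical decay values for a long time. After 1000 updates it is only (1 + 1000)/(10 + 1000) ≈ 0.9911, and with decay = 0.9999 it does not reach the configured value until roughly 90,000 updates. A run with few (<1000) epochs and one EMA update per epoch therefore never uses the configured decay.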
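The effective decay can be tabulated with a few lines of plain Python. This sketch re-implements the warmup rule described above instead of importing LitEma. effective_decay is a hypothetical helper.

```python
def effective_decay(decay: float, num_updates: int, use_num_updates: bool = True) -> float:
    """Decay applied on the num_updates-th EMA update.

    Re-implements the warmup rule described above (not imported from LitEma):
    with use_num_updates=True the configured decay is capped by (1 + n) / (10 + n).
    """
    if not use_num_updates:
        return decay
    return min(decay, (1 + num_updates) / (10 + num_updates))


configured = 0.9999
for n in (1, 10, 100, 1000, 10000):
    warm = effective_decay(configured, n)
    fixed = effective_decay(configured, n, use_num_updates=False)
    # 1 - decay is the weight given to the newest model weights on this update
    print(f"n={n:>5}  warmup={warm:.4f}  fixed={fixed:.4f}  "
          f"new-weight share {1 - warm:.4f} vs {1 - fixed:.4f}")
```

With the warmup enabled, the newest weights get almost 0.9% of the weight per update at n = 1000, compared with 0.01% at the configured decay of 0.9999. The average is therefore much less smoothed than intended.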
I recommend the following changes. In the model's __init__:

    if self.use_ema:
        self.ema_decay = ema_decay
        assert 0. < ema_decay < 1.
        # Some versions of ldm/modules/ema.py spell this keyword "use_num_upates"; match your copy
        self.model_ema = LitEma(self.model, decay=ema_decay, use_num_updates=False)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

and in the epoch-end hook:

    def on_train_epoch_end(self):
        if self.use_ema:
            self.model_ema(self.model)  # Update EMA shadow parameters with the latest model weights
            self.model_ema.copy_to(self.model)  # Copy EMA shadow parameters into the model weights
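The scalar sketch below shows what a fixed decay (use_num_updates=False) does over a few per-epoch updates. ema_step is a hypothetical helper. The update form shadow - (1 - decay) * (shadow - param) is assumed to match the in-place update LitEma applies to each shadow buffer. Scalars stand in for tensors.

```python
def ema_step(shadow: float, param: float, decay: float) -> float:
    """One EMA update (assumed to match LitEma's per-buffer update):
    shadow <- shadow - (1 - decay) * (shadow - param)."""
    return shadow - (1.0 - decay) * (shadow - param)


# With a fixed decay and a constant model weight, each update shrinks the gap
# between shadow and weight by a factor of decay.
shadow, param, decay = 0.0, 1.0, 0.5
for epoch in range(3):  # one update per epoch
    shadow = ema_step(shadow, param, decay)
print(shadow)  # 0.875
```

After k updates the remaining gap is decay**k of its starting value. With decay = 0.9999, 1000 updates close only about 1 - 0.9999**1000 ≈ 9.5% of the gap.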