Leminhbinh0209 / FinetuneVAE-SD

Fine-tune VAE of Stable Diffusion model

EMA Update Isn't Working as Expected #3

Open JStyborski opened 1 month ago

JStyborski commented 1 month ago

A few separate issues:

  1. The PyTorch Lightning hook is named `on_train_epoch_end()`, so the `train_epoch_end()` method defined here is never actually called.
  2. Instantiating LitEma with `LitEma(self, ...)` (instead of `LitEma(self.model, ...)`) stores incorrect model parameter names in its internal dictionary, resulting in an error.
  3. Because of the default setting `use_num_updates=True` in LitEma, the declared decay factor isn't actually used, especially for short runs (fewer than ~1000 epochs). In its `forward` method, LitEma overrides the decay with `(1 + num_updates) / (10 + num_updates)`, which only approaches the declared value after many updates (e.g., `(1 + 1000) / (10 + 1000) ≈ 0.9911`); see the sketch after this list.

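To illustrate point 3, here is a minimal sketch of the decay schedule, assuming LitEma follows the CompVis latent-diffusion behavior of taking the minimum of the declared decay and `(1 + num_updates) / (10 + num_updates)`:

```python
# Sketch of LitEma's effective decay when use_num_updates=True
# (assumes the CompVis latent-diffusion behavior:
#  effective = min(declared, (1 + n) / (10 + n)) at the n-th update).
declared_decay = 0.9999

for n in (1, 10, 100, 1000):
    effective = min(declared_decay, (1 + n) / (10 + n))
    print(f"update {n:5d}: effective decay = {effective:.4f}")

# With one EMA update per epoch, the effective decay only reaches ~0.99
# after roughly 1000 epochs, so the declared value is never used in
# shorter runs.
```
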
Recommended updates:

```python
if self.use_ema:
    self.ema_decay = ema_decay
    assert 0. < ema_decay < 1.
    self.model_ema = LitEma(self.model, decay=ema_decay, use_num_updates=False)
    print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
```

and

```python
def on_train_epoch_end(self):
    if self.use_ema:
        self.model_ema(self.model)          # Update EMA shadow parameters with latest model weights
        self.model_ema.copy_to(self.model)  # Copy EMA shadow parameters back into the model weights
```

JStyborski commented 1 month ago

Alternatively, you might consider using `on_train_batch_end()` instead, which would result in many more EMA updates per epoch.
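A minimal sketch of that variant (the hook body is hypothetical; the argument names follow PyTorch Lightning's `on_train_batch_end` signature, which may differ across versions):

```python
def on_train_batch_end(self, outputs, batch, batch_idx):
    # Update the EMA shadow parameters after every training batch
    # instead of once per epoch.
    if self.use_ema:
        self.model_ema(self.model)
```

With per-batch updates, `copy_to` would typically be called only when the EMA weights are actually needed (e.g., before validation or checkpointing), rather than after every batch.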

Leminhbinh0209 commented 1 month ago

Thank you for your helpful comments!