VectorSpaceLab / OmniGen

OmniGen: Unified Image Generation. https://arxiv.org/pdf/2409.11340
MIT License

Saving of EMA state dict in train.py #59

Closed: brycegoh closed this issue 4 days ago

brycegoh commented 4 days ago

Hi, can I check if this is a typo in the training script?

if ema_state_dict is not None:
  checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}_ema"
  os.makedirs(checkpoint_path, exist_ok=True)
  ## Should we be saving the ema_state_dict instead?
  torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
  processor.text_tokenizer.save_pretrained(checkpoint_path)
  model.llm.config.save_pretrained(checkpoint_path)

Should torch.save(state_dict, os.path.join(checkpoint_path, "model.pt")) be torch.save(ema_state_dict, os.path.join(checkpoint_path, "model.pt")) instead?

Link to code snippet: https://github.com/VectorSpaceLab/OmniGen/blob/d89f9d42dde00d55a886a49144178911b5309830/train.py#L271-L275
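To make the difference concrete, here is a minimal, self-contained sketch. It is not the code from train.py or from the PR, and it rests on several assumptions. It assumes a standard EMA update, ema = decay * ema + (1 - decay) * w. It uses plain floats in place of tensors and pickle in place of torch.save, so it runs without PyTorch. The helper names (update_ema, save_state, save_checkpoints, load_state) are hypothetical. The step is passed in directly rather than computed as int(train_steps/args.gradient_accumulation_steps). The sketch shows that the *_ema directory only holds different weights from the raw checkpoint when ema_state_dict, not state_dict, is what gets written there.

```python
import copy
import os
import pickle
import tempfile


def update_ema(ema_state, model_state, decay=0.999):
    """In-place EMA update (assumed form): ema = decay * ema + (1 - decay) * w."""
    for name, value in model_state.items():
        ema_state[name] = decay * ema_state[name] + (1.0 - decay) * value


def save_state(state, checkpoint_dir, step, suffix=""):
    """Write one state dict to <checkpoint_dir>/<step:07d><suffix>/model.pt.

    pickle stands in for torch.save so this sketch runs without PyTorch.
    """
    checkpoint_path = os.path.join(checkpoint_dir, f"{step:07d}{suffix}")
    os.makedirs(checkpoint_path, exist_ok=True)
    with open(os.path.join(checkpoint_path, "model.pt"), "wb") as f:
        pickle.dump(state, f)
    return checkpoint_path


def save_checkpoints(state_dict, ema_state_dict, checkpoint_dir, step):
    """Hypothetical helper: raw weights go to <step>, EMA weights to <step>_ema."""
    save_state(state_dict, checkpoint_dir, step)
    if ema_state_dict is not None:
        # The point raised in this issue: the _ema directory must receive
        # ema_state_dict, not state_dict.
        save_state(ema_state_dict, checkpoint_dir, step, suffix="_ema")


def load_state(checkpoint_dir, name):
    """Read back a pickled state dict from <checkpoint_dir>/<name>/model.pt."""
    with open(os.path.join(checkpoint_dir, name, "model.pt"), "rb") as f:
        return pickle.load(f)


# Toy run: one float stands in for a weight tensor.
model_state = {"w": 1.0}
ema_state = copy.deepcopy(model_state)
model_state["w"] = 2.0  # pretend an optimizer step changed the weight
update_ema(ema_state, model_state, decay=0.9)

with tempfile.TemporaryDirectory() as tmp:
    save_checkpoints(model_state, ema_state, tmp, step=10)
    raw = load_state(tmp, "0000010")
    ema = load_state(tmp, "0000010_ema")

# raw["w"] keeps the latest weight; ema["w"] lags behind at roughly 1.1.
print(raw["w"], ema["w"])
```

With the typo from the snippet above (passing state_dict in the EMA branch), both loaded files would contain the same value, so the *_ema checkpoint would silently duplicate the raw weights.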

staoxiao commented 4 days ago

You're right. Thanks for the reminder! You can submit a PR to become a contributor, or I can fix the issue myself later.

brycegoh commented 4 days ago

Created a PR here https://github.com/VectorSpaceLab/OmniGen/pull/60

Let me know if I missed anything else. Thanks!