apple / ml-mdm

Train high-quality text-to-image diffusion models in a data & compute efficient manner
https://machinelearning.apple.com/research/matryoshka-diffusion-models
MIT License

Add type hinting to train_batch function #24

Status: Open, opened by luke-carlson 3 weeks ago

luke-carlson commented 3 weeks ago

trainer.py's train_batch function takes a number of arguments. It would be nice if each of them had an associated type hint (e.g. bool, int, EmaModel).

def train_batch(
    model,
    sample,
    optimizer,
    scheduler,
    logger,
    args,
    grad_scaler=None,
    accumulate_gradient=False,
    num_grad_accumulations=1,
    ema_model=None,
    loss_factor=1,
):
    ...  # function body not shown in the issue

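A minimal sketch of what the annotated signature could look like, assuming PyTorch types for the model, optimizer, scheduler and gradient scaler. The hints for sample, logger and args are guesses (a dict-like batch, a standard logging.Logger, an argparse.Namespace) and should be checked against how trainer.py actually uses them. EmaModel is the class name mentioned in the issue; its import path is not given, so that import is left as a placeholder comment. The return type is also not stated, so Any is used.

```python
from __future__ import annotations  # annotations stay strings and are never evaluated at runtime

import argparse
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional

if TYPE_CHECKING:
    # Only static type checkers (e.g. mypy) run this block,
    # so torch is not imported when the module executes.
    import torch
    # Assumption: add the real EmaModel import from ml-mdm here;
    # its module path is not given in the issue.


def train_batch(
    model: torch.nn.Module,
    sample: Dict[str, Any],  # assumption: a dict-like batch
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,  # public name in torch >= 2.0
    logger: logging.Logger,  # assumption: may be a custom metrics logger instead
    args: argparse.Namespace,  # assumption: parsed command-line config
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    accumulate_gradient: bool = False,
    num_grad_accumulations: int = 1,
    ema_model: Optional[EmaModel] = None,
    loss_factor: float = 1,  # original default kept; int is accepted where float is expected
) -> Any:  # return type not stated in the issue
    ...  # body unchanged from trainer.py
```

Combining from __future__ import annotations with a TYPE_CHECKING guard keeps the hints free at runtime: torch is only imported by the type checker, and none of the annotations are evaluated when the module loads. The defaults are unchanged, so existing callers are unaffected.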
emmagarr commented 2 weeks ago

I can work on this!

emmagarr commented 1 week ago

https://github.com/apple/ml-mdm/pull/33