microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License
1.24k stars, 88 forks

Should `base=None` be used in `set_base_shapes` for model used for tuning? #25

Open callumm-graphcore opened 1 year ago

callumm-graphcore commented 1 year ago

Hello! First of all, thank you for doing such great work and making it so accessible. I'm looking at using mup for a project but I'm a bit confused about how to set the base shapes for the smaller model used for hyperparameter tuning.

Let's say I want to train an MLP with hidden dimension 1024, and I want to muTransfer the best learning rate from an MLP with hidden dimension 128. My top-level code might look like this:

import mup

best_loss = float('inf')
best_lr = 0.

# Hyperparameter sweep with hidden dimension 128
for lr in learning_rates:

    small_mlp = MLP(hidden_dim=128)

    # use `base=None` in `set_base_shapes`
    small_mlp = mup.set_base_shapes(small_mlp, base=None)

    final_loss = full_training_loop(small_mlp, lr=lr)

    if final_loss < best_loss:
        best_loss = final_loss
        best_lr = lr

# Transfer optimal LR to large model

base_mlp = MLP(hidden_dim=128)
big_mlp = MLP(hidden_dim=1024)

big_mlp = mup.set_base_shapes(big_mlp, base=base_mlp)

ultimate_loss = full_training_loop(big_mlp, lr=best_lr)

or like this:

import mup

best_loss = float('inf')
best_lr = 0.

for lr in learning_rates:

    small_mlp = MLP(hidden_dim=128)

    # use a base model in `set_base_shapes`
    smaller_mlp = MLP(hidden_dim=32)
    small_mlp = mup.set_base_shapes(small_mlp, base=smaller_mlp)

    final_loss = full_training_loop(small_mlp, lr=lr)

    if final_loss < best_loss:
        best_loss = final_loss
        best_lr = lr

# Transfer optimal LR to large model

base_mlp = MLP(hidden_dim=128)
big_mlp = MLP(hidden_dim=1024)

big_mlp = mup.set_base_shapes(big_mlp, base=base_mlp)

ultimate_loss = full_training_loop(big_mlp, lr=best_lr)

Could you please clarify which of these would be correct? Thank you very much for your time!

thegregyang commented 1 year ago

Thanks for the kind words!

You should do the 2nd thing. `base=None` essentially means not using muP.
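To illustrate the point, here is a minimal sketch in plain Python. It is not the mup API. It assumes that passing `base=None` makes the model act as its own base, so each dimension's ratio to its base dimension is 1. The helpers `mlp_shapes` and `width_multipliers` are hypothetical, and the input and output sizes (784 and 10) are arbitrary choices for the example.

```python
# Illustrative only: hypothetical helpers, not functions from the mup package.

def mlp_shapes(hidden_dim, in_dim=784, out_dim=10):
    """Weight shapes of a two-hidden-layer MLP (in_dim/out_dim are arbitrary here)."""
    return {
        "fc1.weight": (hidden_dim, in_dim),
        "fc2.weight": (hidden_dim, hidden_dim),
        "out.weight": (out_dim, hidden_dim),
    }

def width_multipliers(shapes, base_shapes=None):
    """Ratio of each dimension to the matching base dimension.

    Assumption: with no base, the model is compared against itself.
    """
    if base_shapes is None:
        base_shapes = shapes
    return {
        name: tuple(d / b for d, b in zip(shape, base_shapes[name]))
        for name, shape in shapes.items()
    }

# Option 1 from the question: no base model.
no_base = width_multipliers(mlp_shapes(128))

# Option 2 from the question: a narrower model (hidden_dim=32) as the base.
with_base = width_multipliers(mlp_shapes(128), mlp_shapes(32))

for name in no_base:
    print(name, no_base[name], with_base[name])
```

Under that assumption, every ratio in `no_base` is 1, so there is no width information to rescale by, which is consistent with `base=None` behaving like the standard parametrization. With the narrower base, the hidden dimensions get a ratio of 4 while the fixed input and output dimensions stay at 1.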

callumm-graphcore commented 1 year ago

Great, thanks Greg!