facebookresearch / schedule_free

Schedule-Free Optimization in PyTorch

Compatibility with MuP (Maximal Update Parametrization) #24

Closed: simonguozirui closed this issue 4 months ago

simonguozirui commented 6 months ago

Curious if anyone has tried using the schedule-free optimizer while training with Maximal Update Parametrization (paper, implementation)?

norikazu99 commented 4 months ago

MuP and Schedule-Free seem to be complementary, even though the theory behind schedule-free isn't fully understood yet. The former is an initialization/parametrization framework, whereas the latter is essentially a new optimizer/schedule (a family of optimizers/schedules). The implementations also appear compatible: MuP can set the lr (and weight decay) per param group at init, and those param_groups can then be passed to the schedule-free optimizer (the wrappers seem compatible too). MuP only modifies the lr at init, while schedule-free reads the param_groups on every step, which respects MuP's constraint that schedulers set the lr relative to what is stored in the param_groups.

from collections import defaultdict
import torch
from schedulefree import AdamWScheduleFree, AdamWScheduleFreeClosure

def process_param_groups(params, **kwargs):
    param_groups = list(params)
    if not isinstance(param_groups[0], dict):
        param_groups = [{"params": param_groups}]
    for param_group in param_groups:
        if "lr" not in param_group:
            param_group["lr"] = kwargs["lr"]
        if "weight_decay" not in param_group:
            param_group["weight_decay"] = kwargs.get("weight_decay", 0.)
    return param_groups

def MuAdamW_ScheduleFree(params, impl=AdamWScheduleFree, decoupled_wd=False, **kwargs):
    """Adam with μP scaling.

    Note for this to work properly, your model needs to have its base shapes set
    already using `mup.set_base_shapes`.

    Inputs:
        impl: the specific Adam-like optimizer implementation, e.g.
            AdamWScheduleFree or AdamWScheduleFreeClosure from schedulefree
        decoupled_wd: if True, skips the mup scaling for weight decay, which should
            be used for optimizer implementations that decouple weight decay from
            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
    Outputs:
        An instance of `impl` with refined parameter groups, each of which has the correctly
        scaled learning rate according to mup.
    """
    new_param_groups = []
    for param_group in process_param_groups(params, **kwargs):
        # For every existing param group, we split into several new groups
        def new_group():
            new_g = {k:v for k, v in param_group.items() if k != "params"}
            new_g["params"] = []
            return new_g
        # The matrix-like weights might need multiple groups since weights
        # might have different width multipliers
        matrix_like_p = defaultdict(new_group) # key is width_mult
        vector_like_p = new_group()
        for p in param_group["params"]:
            assert hasattr(p, "infshape"), (
                f"A parameter with shape {p.shape} does not have `infshape` attribute. "
                "Did you forget to call `mup.set_base_shapes` on the model?")
            if p.infshape.ninf() == 2:
                matrix_like_p[p.infshape.width_mult()]["params"].append(p)
            elif p.infshape.ninf() > 2:
                raise NotImplementedError("more than 2 inf dimensions")
            else:
                vector_like_p["params"].append(p)
        for width_mult, group in matrix_like_p.items():
            # μP rule for Adam-like updates: matrix-like params (2 infinite dims)
            # use lr / width_mult; weight decay is multiplied by width_mult so the
            # effective lr * wd stays constant (skipped when wd is decoupled from lr)
            group["lr"] /= width_mult
            if not decoupled_wd:
                group["weight_decay"] *= width_mult
        new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
    return impl(new_param_groups, **kwargs)

I'm about to use this, and it likely does the trick; I just copied and pasted the relevant parts from both repos. I'll share how it goes if no one else does first. Please let me know if anyone thinks I missed something.
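
For anyone else trying this, here is a minimal usage sketch under a few assumptions: `make_model` and `loader` are hypothetical placeholders, the model is built with mup-aware layers so `set_base_shapes` can attach `infshape` to every parameter, and the schedule-free optimizer still needs its explicit `optimizer.train()` / `optimizer.eval()` mode switching:

import torch
from mup import set_base_shapes
from schedulefree import AdamWScheduleFree

# `make_model(width)` is a hypothetical factory that builds the same
# architecture at a given width using mup-aware layers.
model = make_model(width=1024)
base = make_model(width=256)        # narrower "base" model defining base shapes
set_base_shapes(model, base)        # attaches `infshape` to every parameter

optimizer = MuAdamW_ScheduleFree(
    model.parameters(),
    impl=AdamWScheduleFree,
    lr=1e-3,
    warmup_steps=1000,
)

# Schedule-free optimizers require explicit mode switching.
optimizer.train()
for batch in loader:                # `loader` is assumed to exist
    optimizer.zero_grad()
    loss = model(batch).mean()      # placeholder loss for illustration
    loss.backward()
    optimizer.step()
optimizer.eval()                    # switch back before evaluation or checkpointing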

simonguozirui commented 4 months ago

@norikazu99 thanks so much for looking into this!