GT-RIPL / Selective-Projection-Decay

This repo implements Adam with SPD (Selective Projection Decay) regularization.

Selective-Projection-Decay

This repo implements the AdamSPD optimizer from the paper [Rethinking Weight Decay for Robust Fine-Tuning of Foundation Models]().

Use AdamSPD in Your Project

```python
import copy

from adamSPD import AdamSPD

# `model` (already loaded with pre-trained weights) and `args` come from your training script.
optimizer_params = {
    "lr": args.lr,
    "weight_decay": args.weight_decay,
}  # Initialize optimizer hyperparameters
params_to_opt = [p for p in model.parameters() if p.requires_grad]
params_anchor = copy.deepcopy(params_to_opt)  # Cache pre-trained weights before any update
param_group = [{'params': params_to_opt,
                'pre': params_anchor}]
optimizer = AdamSPD(param_group, **optimizer_params)
```
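The anchor has to be a snapshot taken before training starts: PyTorch optimizers update parameters in place, so a list that merely aliases `params_to_opt` would drift along with the fine-tuned weights. The framework-free sketch below mimics this with plain Python lists standing in for tensors; `sgd_step_in_place` is a hypothetical toy update used only for illustration and is not part of this repo.

```python
import copy

# Stand-ins for parameter tensors: mutable lists that get updated in place,
# the way torch optimizers mutate parameters during training.
params_to_opt = [[1.0, 2.0], [3.0]]

alias = params_to_opt                    # same objects, so NOT a snapshot
snapshot = copy.deepcopy(params_to_opt)  # independent copy of the values


def sgd_step_in_place(params, grads, lr):
    """Hypothetical toy update: p <- p - lr * g, mutating each list in place."""
    for p, g in zip(params, grads):
        for i in range(len(p)):
            p[i] -= lr * g[i]


sgd_step_in_place(params_to_opt, [[1.0, 1.0], [1.0]], lr=0.5)
print(alias)     # [[0.5, 1.5], [2.5]]  (moved with the training weights)
print(snapshot)  # [[1.0, 2.0], [3.0]]  (still the pre-trained values)
```

In PyTorch, `[p.detach().clone() for p in params_to_opt]` is a common alternative to `copy.deepcopy` that produces anchor tensors which do not require gradients; this should be interchangeable here, assuming AdamSPD only reads the anchor values.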
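To use AdamSPD without a cached pre-trained anchor, skip the copy and set `'pre'` to `None`:

```python
from adamSPD import AdamSPD

# `model` and `args` come from your training script.
optimizer_params = {
    "lr": args.lr,
    "weight_decay": args.weight_decay,
}  # Initialize optimizer hyperparameters
params_to_opt = [p for p in model.parameters() if p.requires_grad]
param_group = [{'params': params_to_opt,
                'pre': None}]  # No pre-trained anchor
optimizer = AdamSPD(param_group, **optimizer_params)
```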
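As the "Cache pre-trained model weights" comment suggests, the `'pre'` entry is the reference point the decay pulls weights toward. The framework-free sketch below applies one plain, non-selective decoupled decay step toward an anchor and treats a missing anchor as all zeros, which reduces to ordinary AdamW-style weight decay. Both the helper `decay_toward_anchor` and the zero fallback are assumptions for illustration: check `adamSPD.py` for how `'pre': None` is actually handled. SPD's selective rule, which decides per layer whether to apply the decay, is deliberately not reproduced here.

```python
from typing import List, Optional, Sequence


def decay_toward_anchor(
    params: Sequence[float],
    anchor: Optional[Sequence[float]],
    lr: float,
    weight_decay: float,
) -> List[float]:
    """Hypothetical helper: one non-selective decoupled decay step.

    theta <- theta - lr * weight_decay * (theta - anchor)
    With anchor=None a zero anchor is ASSUMED, giving AdamW-style decay
    toward the origin. This is not the SPD rule itself.
    """
    if anchor is None:
        anchor = [0.0] * len(params)  # assumed fallback: decay toward zero
    if len(anchor) != len(params):
        raise ValueError("anchor must have the same length as params")
    return [p - lr * weight_decay * (p - a) for p, a in zip(params, anchor)]


theta = [1.0, 2.0]  # current (fine-tuned) weights
pre = [0.5, 2.0]    # cached pre-trained weights
print(decay_toward_anchor(theta, pre, lr=0.1, weight_decay=1.0))   # only the first weight moves, toward 0.5
print(decay_toward_anchor(theta, None, lr=0.1, weight_decay=1.0))  # both weights shrink toward 0
```

Decaying toward the pre-trained weights leaves a weight that has not deviated (the second entry above) untouched, whereas decaying toward zero shrinks every weight regardless of where fine-tuning started.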