GT-RIPL / Selective-Projection-Decay

This repo implements Adam + SPD (Selective Projection Decay) regularization.

Quick question #1

Open JohnMBrandt opened 1 week ago

JohnMBrandt commented 1 week ago

Thank you so much for this research. I've wondered for a long time whether weight decay was leading to suboptimal results when fine-tuning transformers.

I work on fine-tuning vision transformers, mostly within MMDetection and MMSegmentation, and have successfully ported this work to those toolkits.

I was wondering, though, how you suggest applying the optimizer when attaching a new head to a pretrained backbone. Your work suggests that only a few layers need to be adjusted, whereas a newly attached head needs to be adjusted in its entirety. Is there a way to use normal AdamW with constant weight decay on the head and AdamSPD with variable weight decay on the backbone? Or does it matter?

My approach has been to:

Modify 'params' to take the parameter name:

params = [{'params': params_to_opt,
           'pre': params_anchor,
           'name': params_name}]

Selectively apply SPD only if the parameter name includes 'backbone':

if 'backbone' not in group['param_names'][j]:
    new_p = new_p - weight_decay * new_p
    condition = 0
else:
    condition = -torch.sum(torch.mul(grad, param - pre))

if condition < 0.0:
    ratio = self._ratio(new_p, param, pre)
    new_p = new_p - weight_decay * ratio * (new_p - pre)
param.copy_(new_p)
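
For context, here is a minimal sketch of how this branch can sit inside the optimizer's parameter loop. It is not the repository's exact implementation: it assumes each param group carries parallel 'params', 'pre', and 'param_names' lists (as in the constructor below), that updated[j] is the Adam update already computed for group['params'][j], and that ratio_fn is a stand-in for the repo's self._ratio.

import torch

@torch.no_grad()
def selective_decay_step(group, updated, ratio_fn):
    weight_decay = group['weight_decay']
    for j, param in enumerate(group['params']):
        if param.grad is None:
            continue
        grad, new_p, pre = param.grad, updated[j], group['pre'][j]

        if 'backbone' not in group['param_names'][j]:
            # Head parameters: ordinary (decoupled) weight decay.
            new_p = new_p - weight_decay * new_p
            condition = 0
        else:
            # Backbone parameters: SPD condition on the alignment between
            # the gradient and the displacement from the pretrained anchor.
            condition = -torch.sum(torch.mul(grad, param - pre))

        if condition < 0.0:
            # Decay toward the pretrained anchor rather than toward zero.
            ratio = ratio_fn(new_p, param, pre)
            new_p = new_p - weight_decay * ratio * (new_p - pre)
        param.copy_(new_p)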

Also, if interested, here is the MMEngine optimizer constructor that ports this to MMSegmentation + MMDetection:

import copy

from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class SPDWeightDecayConstructor(DefaultOptimWrapperConstructor):
    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups; it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module.
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control the conv_offset layer's learning rate. Defaults to None.
        """
        parameter_groups = {}

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # skip frozen weights
            group_name = name

            # Stronger decay and a reduced learning rate on the pretrained
            # backbone (excluding the FPN); lighter decay everywhere else.
            is_backbone = 'backbone' in name and 'fpn' not in name
            this_weight_decay = 0.01 if is_backbone else 1e-4
            scale = 1e-1 if is_backbone else 1
            # No weight decay on normalization layers.
            if 'norm' in name:
                this_weight_decay = 0

            parameter_groups[group_name] = {
                'weight_decay': this_weight_decay,
                # Parallel lists so the optimizer step can index
                # group['params'][j], group['param_names'][j], and group['pre'][j].
                'params': [param],
                'param_names': [group_name],
                # SPD anchor: a copy of the pretrained weights taken before training.
                'pre': [copy.deepcopy(param)],
                'lr_scale': scale,
                'name': group_name,
                'lr': scale * self.base_lr,
            }
        params.extend(parameter_groups.values())
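
With the constructor registered, it can be selected from the optim_wrapper config. A sketch of the usage (the optimizer type 'AdamSPD' below is a placeholder for whatever name the SPD optimizer is registered under in your setup, and the lr/betas/weight_decay values are illustrative):

optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamSPD', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01),
    # Use the custom constructor registered above instead of the default one.
    constructor='SPDWeightDecayConstructor',
)
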
PotatoTian commented 1 week ago

Hi,

Thank you for the code pointer! If the added head is randomly initialized, you could set its corresponding 'pre' (params_anchor) entries to zero tensors. In that case, AdamSPD reduces to a learnable normal weight decay on the head. Keep me posted on your results and discoveries; I am interested in seeing how generalizable this method is to different settings.
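
A minimal sketch of that zeroing step (assuming params is the list of group dicts built by the constructor above, with parallel 'params', 'param_names', and 'pre' lists):

import torch

for group in params:
    for j, name in enumerate(group['param_names']):
        if 'backbone' not in name:
            # Randomly initialized head: anchor at zero, so decaying toward
            # 'pre' becomes ordinary weight decay toward zero, gated by the
            # SPD condition (i.e., a learnable normal weight decay).
            group['pre'][j] = torch.zeros_like(group['params'][j])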

Best,