microsoft / Swin-Transformer

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".
https://arxiv.org/abs/2103.14030

no weight decay setting error in SimMIM pretraining #369

Open wanghaoyucn opened 1 month ago

wanghaoyucn commented 1 month ago

https://github.com/microsoft/Swin-Transformer/blob/f82860bfb5225915aca09c3227159ee9e1df874d/models/simmim.py#L154-L158

Hello, I found that this method of the SimMIM class returns {'encoder.cpb_mlp', 'encoder.logit_scale', 'encoder.relative_position_bias_table'}. When that set is passed into build_optimizer, it is eventually handed to check_keywords_in_name(name, skip_keywords), which decides whether a parameter's weight decay should be set to 0.
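For context, judging from the reported output, the linked function roughly looks like the sketch below: it takes the keywords returned by the encoder and prefixes each one with 'encoder.' (the exact body is at the commit linked above):

```python
@torch.jit.ignore
def no_weight_decay_keywords(self):
    if hasattr(self.encoder, 'no_weight_decay_keywords'):
        # prefixes every encoder keyword with 'encoder.'
        return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
    return {}
```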

https://github.com/microsoft/Swin-Transformer/blob/f82860bfb5225915aca09c3227159ee9e1df874d/optimizer.py#L76-L81
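That check is a plain substring test, roughly equivalent to this sketch (the parameter name below is taken from the pretraining model to illustrate the mismatch):

```python
def check_keywords_in_name(name, keywords=()):
    # True if any keyword occurs as a substring of the parameter name
    return any(keyword in name for keyword in keywords)

skip_keywords = {'encoder.cpb_mlp', 'encoder.logit_scale', 'encoder.relative_position_bias_table'}
name = 'encoder.layers.0.blocks.0.attn.cpb_mlp.0.bias'
print(check_keywords_in_name(name, skip_keywords))  # False: the layer path sits between 'encoder.' and 'cpb_mlp'
print(check_keywords_in_name(name, {'cpb_mlp'}))    # True: the unprefixed keyword matches
```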

Sadly, 'encoder.cpb_mlp' in 'encoder.layers.0.blocks.0.attn.cpb_mlp.0.bias' evaluates to False, because the layer path sits between the 'encoder.' prefix and 'cpb_mlp'. This means the weight decay of cpb_mlp (and the other two keywords) is not set to 0 during pretraining. A correct implementation of no_weight_decay_keywords would be:

```python
@torch.jit.ignore
def no_weight_decay_keywords(self):
    if hasattr(self.encoder, 'no_weight_decay_keywords'):
        # return the encoder's keywords unprefixed, so the substring check can match
        return set(self.encoder.no_weight_decay_keywords())
    return set()
```
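With the unprefixed keywords the substring check succeeds, so these parameters are actually excluded from weight decay (continuing the check_keywords_in_name sketch above):

```python
skip_keywords = {'cpb_mlp', 'logit_scale', 'relative_position_bias_table'}
print(check_keywords_in_name('encoder.layers.0.blocks.0.attn.cpb_mlp.0.bias', skip_keywords))  # True: no weight decay
print(check_keywords_in_name('encoder.layers.0.blocks.0.attn.qkv.weight', skip_keywords))      # False: still decayed
```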

Is this intentional behavior or a bug? I appreciate your help!