karpathy / minGPT

A minimal PyTorch re-implementation of the OpenAI GPT (Generative Pretrained Transformer) training
MIT License

Simplifying weight decay checking doesn't work #112

Closed. rabinadk1 closed this issue 1 year ago

rabinadk1 commented 1 year ago

First, I thought the code below, which separates out the parameters that should receive weight decay from those that should not, could be simplified.

decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
    for pn, p in m.named_parameters():
        fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
        # random note: because named_modules and named_parameters are recursive
        # we will see the same tensors p many many times. but doing it this way
        # allows us to know which parent module any tensor p belongs to...
        if pn.endswith('bias'):
            # all biases will not be decayed
            no_decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
            # weights of whitelist modules will be weight decayed
            decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
            # weights of blacklist modules will NOT be weight decayed
            no_decay.add(fpn)

I found the rewritten code below to be more succinct. But interestingly, it resulted in duplicate parameters, i.e. the same parameter ending up in both the decay and no_decay sets.

decay = set()
no_decay = set()
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
    for pn, p in m.named_parameters():
        fpn = f"{mn}.{pn}" if mn else pn  # full param name
        if pn.endswith("bias") or isinstance(m, blacklist_weight_modules):
            no_decay.add(fpn)
        else:
            decay.add(fpn)

Can anyone explain to me why the changed code results in duplicate parameters?
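(For reference, the overlap can be surfaced the same way minGPT's configure_optimizers checks it, by intersecting the two sets. A minimal sketch, where decay and no_decay are the sets built by the snippet above:)

# Minimal sketch: parameters that land in both sets are the "duplicates"
# that configure_optimizers guards against with its inter_params assertion.
inter_params = decay & no_decay
print("parameters in both decay and no_decay:", sorted(inter_params))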

rjarun8 commented 1 year ago

Here is a test-case matrix showing where the modified code behaves differently from the original. The general idea is that the Linear module maps to the whitelist (its weights get decayed) and the rest (LayerNorm, Embedding) belongs to the blacklist.

Test Case | Module | Parameter | Original Code | Modified Code
-- | -- | -- | -- | --
1 | Linear | weight | decay | decay
1 | Linear | bias | no_decay | no_decay
2 | LayerNorm | weight | no_decay | no_decay
2 | LayerNorm | bias | no_decay | no_decay
3 | Embedding | weight | no_decay | no_decay
4 | Custom | custom_param | N/A | decay
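To make this concrete, here is a small self-contained sketch that runs both loops over one model and prints what each assigns to both sets. The Toy module and its attribute names are invented for illustration:

import torch

class Toy(torch.nn.Module):
    """Invented toy model with one module of each kind under discussion."""
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)
        self.ln = torch.nn.LayerNorm(4)
        self.emb = torch.nn.Embedding(8, 4)

def split_original(model):
    # original whitelist/blacklist loop from minGPT
    decay, no_decay = set(), set()
    whitelist = (torch.nn.Linear,)
    blacklist = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn
            if pn.endswith("bias"):
                no_decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, whitelist):
                decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, blacklist):
                no_decay.add(fpn)
    return decay, no_decay

def split_modified(model):
    # simplified loop from the issue, with the bare else
    decay, no_decay = set(), set()
    blacklist = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn
            if pn.endswith("bias") or isinstance(m, blacklist):
                no_decay.add(fpn)
            else:
                decay.add(fpn)
    return decay, no_decay

model = Toy()
for name, (decay, no_decay) in [("original", split_original(model)),
                                ("modified", split_modified(model))]:
    print(name, "in both sets:", sorted(decay & no_decay))
# original: empty; modified: ln.weight and emb.weight, because the root-level
# pass over named_parameters sends them to decay via the bare else, while the
# pass over the LayerNorm/Embedding submodules sends the same fpn to no_decay.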
rjarun8 commented 1 year ago

Modifying the else branch should fix it.

code:

decay = set()
no_decay = set()
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
    for pn, p in m.named_parameters():
        fpn = f"{mn}.{pn}" if mn else pn  # full param name
        if pn.endswith("bias") or isinstance(m, blacklist_weight_modules):
            no_decay.add(fpn)
        elif pn.endswith("weight"):
            decay.add(fpn)
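Whatever variant one settles on, a quick way to sanity-check it is to build the sets with both loops over the same model and diff the results. A rough sketch, where decay_orig/no_decay_orig and decay_new/no_decay_new are hypothetical names for the sets produced by the original loop and the loop above:

# Hypothetical names: *_orig come from the original whitelist/blacklist loop,
# *_new from the modified loop above; both are built over the same model.
print("decayed only by the new loop:", sorted(decay_new - decay_orig))
print("excluded only by the new loop:", sorted(no_decay_new - no_decay_orig))
print("assigned to both sets by the new loop:", sorted(decay_new & no_decay_new))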

rabinadk1 commented 1 year ago

Modifying the else branch should fix it.

code:

decay = set()
no_decay = set()
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
    for pn, p in m.named_parameters():
        fpn = f"{mn}.{pn}" if mn else pn  # full param name
        if pn.endswith("bias") or isinstance(m, blacklist_weight_modules):
            no_decay.add(fpn)
        elif pn.endswith("weight"):
            decay.add(fpn)

Thanks, missed the logic there.