lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

How to build optimizer #230

Closed: pfeatherstone closed this 6 months ago

pfeatherstone commented 6 months ago

Looking at

https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/model.py#L263

https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/model.py#L215

https://github.com/ultralytics/ultralytics/blob/d021524e850acfa393ec25d4ecb9c4c761cca688/ultralytics/engine/trainer.py#L688

I noticed that these repositories carefully build optimizers by splitting parameters into groups that either do or do not get weight decay. All of them agree that biases of any kind should not be decayed, while kernel weights from nn.Linear and nn.ConvNd should. This repository has many kinds of parameters. My question is: where do they fall?

A shortlist of parameters I'm not sure about:

Thank you
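The convention these repositories share (decay Linear/ConvNd kernels, never decay biases) can be sketched as a module-type whitelist, similar in spirit to minGPT's approach. This is a minimal sketch, not code from any of the linked repositories: `split_by_module_type` and `DECAY_MODULES` are hypothetical names, and the choice to leave every non-whitelisted parameter (norm gains, embeddings, custom parameters such as memory tokens) undecayed is an assumption.

```python
import torch
import torch.nn as nn

# Hypothetical whitelist: module types whose `weight` is treated as a kernel and decayed.
# Extend this tuple for other kernel-bearing layers.
DECAY_MODULES = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)


def split_by_module_type(model: nn.Module):
    """Decay only the `weight` of whitelisted layers; everything else is not decayed."""
    decay, no_decay = [], []
    for module in model.modules():
        # recurse=False so each parameter is seen once, together with its owning module
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if name == "weight" and isinstance(module, DECAY_MODULES):
                decay.append(param)
            else:
                # biases, norm gains, embeddings and any custom nn.Parameter
                no_decay.append(param)
    return decay, no_decay


model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
decay, no_decay = split_by_module_type(model)
optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.1}, {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-3,
)
```

With this toy model, the two Linear weights are decayed, and the two Linear biases plus the LayerNorm weight and bias are not.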

pfeatherstone commented 6 months ago

Currently I'm using:

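import torch
import torch.nn as nn
from x_transformers.x_transformers import ScaleNorm, RMSNorm

def createOptimizer(model: torch.nn.Module, betas=(0.9, 0.95), lr=0.001, decay=0.1):
    # every torch.nn normalization class (LayerNorm, BatchNorm*, GroupNorm, ...);
    # the isinstance/issubclass check skips any non-class entry whose name contains "Norm"
    normModules      = tuple(v for k, v in nn.__dict__.items()
                             if "Norm" in k and isinstance(v, type) and issubclass(v, nn.Module))
    blacklistModules = normModules + (nn.Embedding, ScaleNorm, RMSNorm)
    # substrings of parameter names that are never decayed
    blacklistNames   = ["bias", "memory_tokens", "mem_k", "mem_v"]
    decay_params   = []
    nodecay_params = []
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue  # frozen parameters stay out of the optimizer
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if any(substr in fullname for substr in blacklistNames) or isinstance(module, blacklistModules):
                nodecay_params.append(param)
            else:
                decay_params.append(param)

    # sanity check: every trainable parameter landed in exactly one group
    ndecayed            = len(decay_params)
    nnodecayed          = len(nodecay_params)
    ntotal              = len([p for p in model.parameters() if p.requires_grad])
    assert ndecayed + nnodecayed == ntotal, f"bad split: {ndecayed} + {nnodecayed} != {ntotal}"
    optim_groups = [
        {'params': decay_params,   'weight_decay': decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    # fused=True may require the parameters to be on CUDA, depending on the PyTorch version
    optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)
    return optimizer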

I've put the memory tokens (memory_tokens, mem_k, mem_v) in the blacklist, i.e. among the parameters that don't decay. I'm not sure that's correct. I'm treating layers like ScaleNorm and RMSNorm like the other PyTorch normalization layers, so they don't decay either.
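To check where each parameter (memory tokens, ScaleNorm/RMSNorm gains, embeddings) actually ends up, one option is to read the groups back out of the built optimizer and map them to parameter names. This is a sketch: `audit_param_groups` is a hypothetical helper, and the small `nn.Sequential` stands in for an x-transformers model.

```python
import torch
import torch.nn as nn


def audit_param_groups(model: nn.Module, optimizer: torch.optim.Optimizer):
    """Map each named parameter to the weight_decay of the optimizer group that holds it."""
    wd_of = {}
    for group in optimizer.param_groups:
        for p in group["params"]:
            wd_of[id(p)] = group["weight_decay"]
    # None means the parameter is missing from the optimizer (forgotten or frozen)
    return {name: wd_of.get(id(p)) for name, p in model.named_parameters()}


model = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8), nn.Linear(8, 10))
no_decay = [model[0].weight, model[1].weight, model[1].bias, model[2].bias]
optimizer = torch.optim.AdamW(
    [{"params": [model[2].weight], "weight_decay": 0.1},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-3,
)
for name, wd in audit_param_groups(model, optimizer).items():
    print(f"{name:10s} weight_decay={wd}")
```

Because it only inspects `optimizer.param_groups`, the same check works no matter how the split was produced.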

pfeatherstone commented 6 months ago

Basically, I've only just started playing with optimizers and found that they have a massive influence on convergence rate and stability. Duh.

pfeatherstone commented 6 months ago

Can anybody think of any other layers/parameters that shouldn't decay?

lucidrains commented 6 months ago

@pfeatherstone just use get_adam_optimizer from pytorch-custom-utils, which will suit 95% of your optimizer needs: https://github.com/lucidrains/pytorch-custom-utils/blob/main/pytorch_custom_utils/get_adam_optimizer.py#L15

lucidrains commented 6 months ago

pip install pytorch-custom-utils

from pytorch_custom_utils import get_adam_optimizer

lucidrains commented 6 months ago

@pfeatherstone and yeah, typically you just exclude any parameters with ndims <= 1 from weight decay. however, i've also heard from some researchers that it doesn't matter, ymmv

this is out of the scope for this repository though, recommend you just read some papers and decide for yourself
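The ndims <= 1 rule can be sketched in plain PyTorch as follows. This is an illustration of the rule, not the code inside pytorch-custom-utils; `param_groups_by_ndim` is a hypothetical name.

```python
import torch
import torch.nn as nn


def param_groups_by_ndim(model: nn.Module, weight_decay: float = 0.1):
    """Decay parameters with ndim >= 2; leave ndim <= 1 (biases, norm gains) undecayed."""
    params = [p for p in model.parameters() if p.requires_grad]
    decay = [p for p in params if p.ndim >= 2]
    no_decay = [p for p in params if p.ndim < 2]
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


model = nn.Sequential(nn.Embedding(100, 32), nn.LayerNorm(32), nn.Linear(32, 100))
groups = param_groups_by_ndim(model)
optimizer = torch.optim.AdamW(groups, lr=1e-3, betas=(0.9, 0.95))
```

One consequence worth noting: unlike the blacklist in createOptimizer, this rule decays embedding matrices (they are 2-D), and any memory parameters stored as 2-D or higher tensors would be decayed too. The LayerNorm weight and bias and the Linear bias remain undecayed.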

lucidrains commented 6 months ago

@pfeatherstone or hop on eleutherai and consult the crowd intelligence there. everyone has their own opinions about optimizers

pfeatherstone commented 6 months ago

@lucidrains Thank you. It looks like you are doing what nanogpt is doing. That does mean you are decaying normalization weights. I'll have a fiddle. Sorry if this is out of scope.

lucidrains commented 6 months ago

@pfeatherstone well, it isn't that i'm doing what Karpathy is doing; we are both following an early practice from the original transformer training at Brain. however, whether it really matters, or is just passed-down superstition, is still for a future research paper to decide