
nGPT (normalized GPT) - Pytorch

Quick implementation of nGPT, learning entirely on the hypersphere, from NvidiaAI. The open question is whether any loss of expressivity was swept under the rug, but I'll take it in good faith.
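
The core idea from the paper is simple enough to sketch: hidden states live on the unit hypersphere, and each residual step is a normalized interpolation between the current state and the sublayer output, h <- Norm(h + alpha * (h_sublayer - h)). A minimal sketch of that update follows; the tensor names and the alpha value are illustrative placeholders, not the library's internals:

import torch
import torch.nn.functional as F

def l2norm(t):
    # project onto the unit hypersphere
    return F.normalize(t, p = 2, dim = -1)

h = l2norm(torch.randn(2, 2048, 512))          # hidden states on the hypersphere
h_sublayer = l2norm(torch.randn(2, 2048, 512)) # normalized attention / mlp output
alpha = torch.full((512,), 0.05)               # learned per-dimension interpolation ("eigen learning rate")

# normalized residual update from the paper: h <- Norm(h + alpha * (h_sublayer - h))
h = l2norm(h + alpha * (h_sublayer - h))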

This type of network should also be studied in the context of continual learning and loss of plasticity.

Adaptation to vision transformers is here.

Install

$ pip install nGPT-pytorch

Usage

import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    attn_norm_qk = True
)

x = torch.randint(0, 256, (2, 2048))

loss = model(x, return_loss = True)
loss.backward()

logits = model(x) # (2, 2048, 256)
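
For training, the forward-with-loss API above is all you need. Below is a minimal training-loop sketch on random token data; the optimizer choice and hyperparameters are placeholders, and it assumes the library keeps its weights on the hypersphere internally after updates (see train.py for the repo's actual setup):

import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4
)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)  # placeholder optimizer and learning rate

data = torch.randint(0, 256, (2, 2048))  # toy batch of token ids

for _ in range(100):
    loss = model(data, return_loss = True)
    loss.backward()
    optim.step()
    optim.zero_grad()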

Test

Enwik8

$ python train.py

Citations

@inproceedings{Loshchilov2024nGPTNT,
    title   = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
    author  = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273026160}
}
@article{Luo2017CosineNU,
    title     = {Cosine Normalization: Using Cosine Similarity Instead of Dot Product in Neural Networks},
    author    = {Chunjie Luo and Jianfeng Zhan and Lei Wang and Qiang Yang},
    journal   = {ArXiv},
    year      = {2017},
    volume    = {abs/1702.05870},
    url       = {https://api.semanticscholar.org/CorpusID:1505432}
}