hunar4321 / reweight-gpt

Reweight GPT - a simple neural network using transformer architecture for next character prediction
https://www.brainxyz.com/
MIT License

Fourier Embeddings #3

Open mourad-ghafiri opened 8 months ago

mourad-ghafiri commented 8 months ago

I did a small experiment after watching your tutorial. The idea is to convert each token (a word in my case) into a sine signal. I take a context of context_size word tokens and sum up their signals (adding a shift to represent position), then I let the neural network predict the next signal. I use a Fourier transform to recover the predicted token, which is the maximum frequency in my case.
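For intuition, here is a minimal round-trip sketch (my own illustration, with an arbitrary token id): a token id k becomes a sine wave of frequency k, and the index of the peak FFT magnitude recovers k.

import torch

d_model = 1024
t = torch.linspace(0, 1, d_model)

token = 42  # hypothetical token id for illustration
signal = torch.sin(2 * torch.pi * token * t)  # frequency encodes the id

# The dominant bin of the magnitude spectrum is the token id.
spectrum = torch.abs(torch.fft.fft(signal))[:d_model // 2]
print(torch.argmax(spectrum).item())  # prints 42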

Hope this code will help with something :)

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model = 1024

class MyModel(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.w1 = nn.Linear(d_model, d_model)

    def forward(self, inputs):
        # Sample points of the signal on [0, 1].
        t = torch.linspace(0, 1, self.d_model)
        # Sum one sine wave per token: the token id sets the frequency,
        # and a log-scaled phase shift encodes the position.
        signal = torch.zeros(self.d_model)
        for position, token in enumerate(inputs):
            p = torch.log1p(torch.tensor(float(position)))
            signal += torch.sin(2 * torch.pi * token * (t + p))
        # The summed signal is a fixed input; only w1 is trained.
        x = self.w1(signal)
        return x

text= """
In a quiet town where whispers play
Lives a creature night and day
A feline spirit soft and sly
Underneath the moonlit sky
With eyes like orbs of gleaming gold
Stories untold ancient and old
Paws that tread on silent ground
In their steps a mystery found
Whiskers twitch in the gentle breeze
Dancing lightly among the trees
Ears that listen to the night's song
In a world where they belong
Fur as soft as the morning's dew
In shades of black white or blue
They roam the streets without a care
Grace in each step light as air
In gardens lush and fields wide
Their elegant forms do glide
Masters of the shadow's dance
In their gaze you're caught in trance
By day they bask in sunlit beams
In slumber deep chasing dreams
Of mice that scamper in their play
In the realm of night and day
In ancient times they were revered
In pyramids their forms appeared
Guardians of the secrets old
In their eyes the stories told
In alleyways and on the fence
Their mystery makes perfect sense
A creature both wild and tame
Never twice quite the same
They purr like the rolling sea
A sound of peace and mystery
A lullaby for troubled hearts
In their presence warmth imparts
With agile leap and graceful bound
They traverse their hallowed ground
In every movement there's a poem
In every silence a hidden tome
In winter's chill or summer's heat
Their resilience is quite a feat
Adapting with such ease and grace
In every season they find their place
Some say they have nine lives to live
In each one love they freely give
Teachers of the art of being
In their gaze a deeper seeing
In their eyes a galaxy spins
A universe where wonder begins
Each whisker a line of a verse
In their world no need for rehearse
They play with yarn in sheer delight
In their joy the world turns bright
Chasing shadows pouncing on light
In their games a pure delight
At times they seem to ponder deep
Secrets in their hearts they keep
Sages in a furry guise
Wisdom old and worldly wise
"""

text = text.lower()
tokens = text.split()  # split on any whitespace so newlines don't stick to words
vocab = sorted(set(tokens))
# Map each word to a positive integer id (ids start at 1, so 0 stays unused).
int2char = {(index + 1): char for index, char in enumerate(vocab)}
char2int = {char: (index + 1) for index, char in enumerate(vocab)}
encoded = [char2int[char] for char in tokens]

context_size = 4
train = [encoded[i:i+context_size] for i in range(len(encoded)-context_size)]
targets = [encoded[i+context_size] for i in range(len(encoded)-context_size)]

for item in range(len(train)):
    print(f"{' '.join([int2char[c] for c in train[item]])} {train[item]} -> {targets[item]} {int2char[targets[item]]}")

# Each target is the pure sine wave whose frequency is the target token id.
t = torch.linspace(0, 1, d_model)
targets = [torch.sin(2 * torch.pi * torch.tensor(target) * t) for target in targets]

model = MyModel(d_model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

for epoch in range(100):
    for i in range(len(train)):
        y = model(train[i])
        target = targets[i]
        loss = criterion(y, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Note: this reports the loss of the last example in the epoch only.
    print(f"Epoch {epoch} | Loss {loss.item()}")

while True:
    sentence = input("Enter a sentence: ")
    for i in range(300):
        tokens = sentence.lower().split()
        context = tokens[-context_size:]
        encoded = [char2int[word] for word in context if word in char2int]
        with torch.no_grad():
            y = model(encoded)
        # Recover a token distribution from the magnitude spectrum of the
        # predicted signal (only the first half is unique for a real signal).
        fft = torch.abs(torch.fft.fft(y))[:d_model // 2]
        prob = F.softmax(fft * 0.5, dim=0)
        prediction = torch.multinomial(prob, num_samples=1).item()
        if prediction in int2char:  # skip sampled bins that map to no word
            sentence += " " + int2char[prediction]
    print(sentence)
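A possible refinement (my own tweak, not part of the original script): mask the magnitude spectrum to the bins that correspond to real vocabulary ids before sampling, so every draw maps to a word. A sketch:

import torch
import torch.nn.functional as F

def sample_token(fft, int2char, scale=0.5):
    # Zero out spectrum bins that do not correspond to a vocabulary id,
    # then sample from the renormalized softmax distribution.
    valid = torch.zeros_like(fft)
    valid[torch.tensor(list(int2char.keys()))] = 1.0
    prob = F.softmax(fft * scale, dim=0) * valid
    prob = prob / prob.sum()
    return torch.multinomial(prob, num_samples=1).item()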
hunar4321 commented 8 months ago

Very cool! I ran your code and got a very low loss, which is very good 👍
However, it's hard to compare the performance of this algorithm with what is presented in the video tutorial, because the data size is small and you are using words as tokens, so your model is likely to over-fit. Try character-level prediction and check how it performs. If the performance is still good, that means representing tokens as sine waves really does help. Best wishes.
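(For what it's worth, the character-level switch is mainly a change to the tokenization step in the script above; a sketch using the same variable names:)

tokens = list(text)  # characters instead of words
vocab = sorted(set(tokens))
# (the generation loop would then append characters without inserting spaces)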

mourad-ghafiri commented 8 months ago

I found out today that there was prior research on this method :) https://www.youtube.com/watch?v=j7pWPdGEfMA https://www.youtube.com/watch?v=JJR3pBl78zw

mourad-ghafiri commented 8 months ago

Here is a modified version of the code using scikit-learn regression for character-based prediction, with a shifted square-signal representation (inspired by the spikes of biological neurons):

https://github.com/mourad-ghafiri/FourierPositionalEmbeddings/blob/main/square_signals.py

import numpy as np
from sklearn.neural_network import MLPRegressor

text= """In a quiet town where whispers play
Lives a creature night and day
A feline spirit soft and sly
Underneath the moonlit sky
With eyes like orbs of gleaming gold
Stories untold ancient and old
Paws that tread on silent ground
In their steps a mystery found
Whiskers twitch in the gentle breeze
Dancing lightly among the trees
Ears that listen to the night's song
In a world where they belong
Fur as soft as the morning's dew
In shades of black white or blue
They roam the streets without a care
Grace in each step light as air
In gardens lush and fields wide
Their elegant forms do glide
Masters of the shadow's dance
In their gaze you're caught in trance
By day they bask in sunlit beams
In slumber deep chasing dreams
Of mice that scamper in their play
In the realm of night and day
In ancient times they were revered
In pyramids their forms appeared
Guardians of the secrets old
In their eyes the stories told
In alleyways and on the fence
Their mystery makes perfect sense
A creature both wild and tame
Never twice quite the same
They purr like the rolling sea
A sound of peace and mystery
A lullaby for troubled hearts
In their presence warmth imparts
With agile leap and graceful bound
They traverse their hallowed ground
In every movement there's a poem
In every silence a hidden tome
In winter's chill or summer's heat
Their resilience is quite a feat
Adapting with such ease and grace
In every season they find their place
Some say they have nine lives to live
In each one love they freely give
Teachers of the art of being
In their gaze a deeper seeing
In their eyes a galaxy spins
A universe where wonder begins
Each whisker a line of a verse
In their world no need for rehearse
They play with yarn in sheer delight
In their joy the world turns bright
Chasing shadows pouncing on light
In their games a pure delight
At times they seem to ponder deep
Secrets in their hearts they keep
Sages in a furry guise
Wisdom old and worldly wise"""

N = 1024
t = np.linspace(0, 1, N)

vocab = sorted(list(set(text)))
int2char = {(index + 1): char for index, char in enumerate(vocab)}
char2int = {char: (index + 1) for index, char in enumerate(vocab)}
encoded = [char2int[char] for char in text]

context_size = 8
train = [encoded[i:i+context_size] for i in range(len(encoded)-context_size)]
targets = [encoded[i+context_size] for i in range(len(encoded)-context_size)]

def token_to_signal(token, position=0):
    # Represent a token as a shifted square wave (inspired by spikes in
    # biological neural networks); the token id sets the frequency.
    y = np.sign(np.sin(2*np.pi*token*(t + (position/N)*np.pi)))
    return y
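# Sanity check (illustrative, not part of the original script): the
# dominant FFT bin of a square wave is still its fundamental frequency,
# so the token id survives the sign() step.
assert np.argmax(np.abs(np.fft.fft(token_to_signal(42)))[:N // 2]) == 42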

def context_to_signal(context):
    # Encode the context as the sum of the position-shifted square waves
    # of its tokens (the author's stand-in for attention).
    signal = np.zeros(N)
    for i in range(len(context)):
        signal += token_to_signal(context[i], i)
    return signal
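# Quick check (illustrative, not part of the original script): the
# position shift makes the sum order-sensitive, so the same tokens in
# a different order yield a different context signal.
assert not np.allclose(context_to_signal([3, 7]), context_to_signal([7, 3]))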

for i in range(len(train)):
    print(train[i], [int2char[c] for c in train[i]], targets[i], int2char[targets[i]])

X = []
Y = []
for i in range(len(train)):
    X.append(context_to_signal(train[i]))
    Y.append(token_to_signal(targets[i]))

model = MLPRegressor(
    verbose=True,
    hidden_layer_sizes=(N,) * 4,  # four hidden layers of N units each
    solver='adam', activation="relu",
    learning_rate_init=0.001, learning_rate="adaptive",
    batch_size=32, shuffle=True,
    max_iter=10000, tol=0.000001, n_iter_no_change=10000,
    random_state=0
)

model.fit(X, Y)

while True:
    sentence = input("Enter a sentence: ")
    for i in range(500):
        context = sentence[-context_size:]
        encoded = [char2int[char] for char in context if char in char2int]
        context_signal = context_to_signal(encoded)
        y = model.predict([context_signal])
        y = np.array(y).reshape(-1)
        # The predicted character is the dominant frequency of the output signal.
        fft = np.abs(np.fft.fft(y))[:N // 2]
        prediction = np.argmax(fft)
        if prediction in int2char:  # skip bins that map to no character
            sentence += int2char[prediction]
    print(sentence)