aju22 / RoPE-PyTorch

This repository contains an educational implementation of Rotary Positional Encodings (RoPE) in PyTorch. RoPE is a method introduced in the paper RoFormer: Enhanced Transformer with Rotary Position Embedding by Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu.
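At its core, RoPE encodes position by rotating each pair of embedding dimensions $(x_{2i}, x_{2i+1})$ of a query or key at position $m$ by an angle $m\theta_i$:

$$
\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
\qquad
\theta_i = 10000^{-2i/d}
$$

so the dot product between a query at position $m$ and a key at position $n$ depends only on the relative offset $m - n$.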

Request for a PyTorch-based Transformer Code with Integrated RoPE #1

Open Lucienxhh opened 3 months ago

Lucienxhh commented 3 months ago

I've recently been studying the Transformer architecture and learning through PyTorch code. I came across a technique called RoPE that can enhance the performance of Transformers.

Consequently, I searched GitHub for your code repository using the keywords "RoPE in PyTorch."

Unfortunately, it seems that you have only implemented RoPE itself, without integrating it into a Transformer model.

Browsing through your other repositories, I noticed that you've implemented a Transformer in TensorFlow, but the positional encoding is missing there.

Without wishing to interrupt your work and life, I sincerely hope you could share a PyTorch-based Transformer that incorporates RoPE. I have some code that might be of assistance to you.

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
import os
import torch.nn.functional as F
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        # token_embedding: (seq_len, batch, emb_size); slice the cached table to seq_len
        pos = self.pos_embedding[:token_embedding.size(0), :]
        return self.dropout(token_embedding + pos)

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int, emb_size: int, nhead: int,
                 src_vocab_size: int, tgt_vocab_size: int, dim_feedforward: int = 512, dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size, nhead=nhead, num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def reset_parameter(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor,
                src_padding_mask: Tensor, tgt_padding_mask: Tensor, memory_padding_mask: Tensor):

        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor, src_padding_mask: Tensor):

        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(src_emb, src_mask, src_padding_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor, memory_padding_mask: Tensor):

        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        # Use keyword arguments here: passed positionally, tgt_padding_mask would
        # be consumed as memory_mask by nn.TransformerDecoder.forward
        return self.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask,
                                        tgt_key_padding_mask=tgt_padding_mask,
                                        memory_key_padding_mask=memory_padding_mask)
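
Here is a minimal usage sketch for the model above (the sizes are made-up values; note that nn.Transformer defaults to sequence-first (seq_len, batch) inputs):

model = Seq2SeqTransformer(num_encoder_layers=3, num_decoder_layers=3,
                           emb_size=512, nhead=8,
                           src_vocab_size=10000, tgt_vocab_size=10000)

src = torch.randint(0, 10000, (35, 8))  # (src_seq_len, batch)
tgt = torch.randint(0, 10000, (32, 8))  # (tgt_seq_len, batch)

src_mask = torch.zeros(35, 35)                               # no masking on the source
tgt_mask = Transformer.generate_square_subsequent_mask(32)   # causal mask for the decoder
src_padding_mask = torch.zeros(8, 35).bool()                 # (batch, src_seq_len), True = pad
tgt_padding_mask = torch.zeros(8, 32).bool()

logits = model(src, tgt, src_mask, tgt_mask,
               src_padding_mask, tgt_padding_mask, src_padding_mask)
print(logits.shape)  # (tgt_seq_len, batch, tgt_vocab_size)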
viai957 commented 2 months ago

@Lucienxhh I too was looking into something similar. I tried to replace the absolute positional embeddings with the RotaryPositionalEmbeddings provided by @aju22 (I believe he has since updated the repo to use PyTorch). Later, I realized that the RoPE class expects a 4D tensor as input, which is not the conventional structure: the rest of the model passes 3D tensors of shape (batch, seq_len, d_model). A small adapter sketch for this mismatch follows the code.

import torch
import torch.nn as nn
import math

class LayerNormalization(nn.Module):

    def __init__(self, features: int, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter
        self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter

    def forward(self, x):
        # x: (batch, seq_len, hidden_size)
        # Keep the dimension for broadcasting
        mean = x.mean(dim=-1, keepdim=True) # (batch, seq_len, 1)
        std = x.std(dim=-1, keepdim=True)   # (batch, seq_len, 1)
        # eps prevents dividing by zero when std is very small
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

class InputEmbeddings(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, d_model)
        # Multiply by sqrt(d_model) to scale the embeddings according to the paper
        return self.embedding(x) * math.sqrt(self.d_model)

class RotaryPositionalEmbeddings(nn.Module):

    def __init__(self, d_model: int, base: int = 10_000):
        super().__init__()
        self.base = base
        self.d_model = d_model
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        # Reuse the cache if it already covers this sequence length
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        seq_len = x.shape[0]

        # THETA_i = 10,000^(-2i/d), i.e. 1 / 10,000^(2i/d)
        theta = 1. / (self.base ** (torch.arange(0, self.d_model, 2).float() / self.d_model)).to(x.device)

        # Position index -> [0, 1, 2, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).float()

        # Outer product m * THETA: row m is [m*THETA_1, m*THETA_2, ..., m*THETA_d/2]
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)

        # Duplicate the angles so they cover both halves: [THETA_1...THETA_d/2] -> [THETA_1...THETA_d]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache cos/sin with broadcast dims: (seq_len, 1, 1, d_model)
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        d_2 = self.d_model // 2
        # [x_1, ..., x_d] -> [-x_{d/2+1}, ..., -x_d, x_1, ..., x_{d/2}]
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        self._build_cache(x)

        neg_half_x = self._neg_half(x)

        # x_rope = x * cos(m*THETA) + rotate_half(x) * sin(m*THETA)
        x_rope = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

        return x_rope
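
# Shape note (my reading of the indexing above, not documented by @aju22):
# forward() expects x of shape (seq_len, batch, n_heads, d_model); the cached
# cos/sin tensors are (seq_len, 1, 1, d_model) and broadcast over batch and
# heads, which is why a plain (batch, seq_len, d_model) embedding cannot be
# fed in directly.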

class ResidualConnection(nn.Module):

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model # Embedding vector size
        self.h = h # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h # Dimension of vector seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False) # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Just apply the formula from the paper
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (indicating -inf) to the positions where mask == 0
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        # Also return the attention scores, which can be used for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        key = self.w_k(k)   # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)

        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)

class EncoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x

class Decoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return self.proj(x)

class Transformer(nn.Module):

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: RotaryPositionalEmbeddings, tgt_pos: RotaryPositionalEmbeddings, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        # (batch, seq_len, d_model)
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        # (batch, seq_len, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)

def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:

    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layers
    # (RoPE builds its cos/sin cache lazily from the input, so it only needs
    # d_model; passing seq_len as the second argument would overwrite `base`)
    src_pos = RotaryPositionalEmbeddings(d_model)
    tgt_pos = RotaryPositionalEmbeddings(d_model)

    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    # Create the encoder and decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer

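As a possible workaround for the shape mismatch, here is a minimal adapter sketch (my own, not from @aju22's repo; the RoPE3DAdapter name and the single pseudo-head are assumptions for illustration) that lets the 4D RotaryPositionalEmbeddings above consume the conventional (batch, seq_len, d_model) tensors:

import torch
import torch.nn as nn

class RoPE3DAdapter(nn.Module):
    # Hypothetical wrapper: applies a seq-first 4D RoPE module to a
    # batch-first 3D tensor of shape (batch, seq_len, d_model).
    def __init__(self, rope: nn.Module):
        super().__init__()
        self.rope = rope  # e.g. RotaryPositionalEmbeddings(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, d_model) -> (seq_len, batch, 1, d_model)
        x = x.transpose(0, 1).unsqueeze(2)
        x = self.rope(x)
        # (seq_len, batch, 1, d_model) -> (batch, seq_len, d_model)
        return x.squeeze(2).transpose(0, 1)

Note that in the RoFormer formulation RoPE is applied to the per-head query and key vectors inside attention, not to the token embeddings, so another option is to call the RoPE module on query/key right after the heads are split in MultiHeadAttentionBlock.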