windsornguyen / flash-stu

PyTorch implementation of the Spectral Transform Unit.
https://arxiv.org/abs/2409.10489
Apache License 2.0

my pytorch port of STU #1

Closed by kashif 2 weeks ago

kashif commented 2 months ago

Thanks for the repo! Here is my PyTorch port of the STU layer, with an additional step function that might be useful, as asked in https://github.com/google-deepmind/spectral_ssm/issues/1

import torch
import torch.nn as nn
import numpy as np

class STU(nn.Module):
    """Simple STU Layer in PyTorch with support for d_in != d_out."""

    def __init__(
        self,
        d_in: int = 256,
        d_out: int = 256,
        input_len: int = 1024,
        num_eigh: int = 24,
        auto_reg_k_u: int = 3,
        auto_reg_k_y: int = 2,
        learnable_m_y: bool = True,
    ) -> None:
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.input_len = input_len
        # Register the spectral filters as buffers so they move with .to(device)
        eig_vals, eig_vecs = self.get_top_hankel_eigh(input_len, num_eigh)
        self.register_buffer("eig_vals", eig_vals)
        self.register_buffer("eig_vecs", eig_vecs)
        self.k = num_eigh
        self.auto_reg_k_u = auto_reg_k_u
        self.auto_reg_k_y = auto_reg_k_y
        self.learnable_m_y = learnable_m_y
        self.m_x_var = 1.0 / (float(self.d_out) ** 0.5)

        # Initialize parameters
        self.init_m_y = nn.Parameter(
            torch.zeros(self.d_out, self.auto_reg_k_y, self.d_out),
            requires_grad=learnable_m_y,
        )
        # Initialize m_u with a truncated normal scaled by m_x_var. Scale before
        # wrapping in nn.Parameter; multiplying the Parameter afterwards would
        # leave self.m_u as a plain, unregistered tensor the optimizer never sees.
        m_u = nn.init.trunc_normal_(
            torch.empty(self.d_out, self.d_in, self.auto_reg_k_u)
        )
        self.m_u = nn.Parameter(m_u * self.m_x_var)

        self.m_phi = nn.Parameter(torch.zeros(self.d_in * self.k, self.d_out))

        # Initialize state
        self.reset_state(batch_size=1)

    def reset_state(self, batch_size: int) -> None:
        """Reset the state for a new batch."""
        device = self.m_phi.device  # keep state on the same device as the parameters
        self.state = {
            "y": torch.zeros(batch_size, self.auto_reg_k_y, self.d_out, device=device),
            "x": torch.zeros(batch_size, self.auto_reg_k_u, self.d_in, device=device),
        }

    def get_state(self) -> dict:
        """Get the current state."""
        return {k: v.clone() for k, v in self.state.items()}

    def set_state(self, state: dict) -> None:
        """Set the current state."""
        self.state = {k: v.clone() for k, v in state.items()}

    def get_top_hankel_eigh(self, n: int, k: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Get top k eigenvalues and eigenvectors of spectral Hankel matrix."""

        def get_hankel_matrix(n: int) -> np.ndarray:
            z = np.zeros((n, n))
            for i in range(1, n + 1):
                for j in range(1, n + 1):
                    z[i - 1, j - 1] = 2 / ((i + j) ** 3 - (i + j))
            return z

        hankel_matrix = get_hankel_matrix(n)
        eig_vals, eig_vecs = np.linalg.eigh(hankel_matrix)
        return (
            torch.from_numpy(eig_vals[-k:]).float(),
            torch.from_numpy(eig_vecs[:, -k:]).float(),
        )

    def compute_x_tilde(self, inputs: torch.Tensor) -> torch.Tensor:
        """Project input sequence into spectral basis."""
        eig_vals, eig_vecs = self.eig_vals, self.eig_vecs
        b, l, _ = inputs.shape

        # Use only the relevant part of eig_vecs
        eig_vecs_trunc = eig_vecs[:l, :]

        # Weight each input position by the spectral filters (elementwise over l)
        x_tilde = torch.einsum("lk,bld->bkld", eig_vecs_trunc, inputs)

        # Apply eigenvalue scaling
        x_tilde *= eig_vals.view(1, -1, 1, 1) ** 0.25

        # Reshape to (b, l, k * d): bring the time axis forward before flattening,
        # otherwise the reshape mixes different time steps into the feature dim
        x_tilde = x_tilde.permute(0, 2, 1, 3).reshape(b, l, -1)

        # Shift
        x_tilde = torch.roll(x_tilde, shifts=2, dims=1)
        x_tilde[:, :2] = 0  # zero the first two steps, which wrapped around in the roll

        return x_tilde

    def compute_ar_x_preds(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the auto-regressive component of spectral SSM."""
        b, l, _ = x.shape
        o = torch.einsum("oik,bli->bklo", self.m_u, x)
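        # o has shape (b, auto_reg_k_u, l, d_out): the lag-k contribution at each position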

        # Shift the lag-i contribution forward by i steps, then mask the wrapped entries
        o = torch.stack(
            [torch.roll(o[:, i], shifts=i, dims=1) for i in range(self.auto_reg_k_u)],
            dim=1,
        )
        mask = (
            torch.triu(torch.ones(self.auto_reg_k_u, l))
            .unsqueeze(0)
            .unsqueeze(-1)
            .to(x.device)
        )

        return torch.sum(o * mask, dim=1)

    def compute_y_t(self, m_y: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
        """Compute sequence of y_t given a series of deltas and m_y."""
        b, l, _ = deltas.shape
        ys = []

        for i in range(l):
            output = torch.einsum("oky,bky->bo", m_y, self.state["y"]) + deltas[:, i]
            ys.append(output)
            self.state["y"] = torch.roll(self.state["y"], 1, dims=1)
            self.state["y"][:, 0] = output

        return torch.stack(ys, dim=1)

    def forward(
        self, inputs: torch.Tensor, initial_state: dict | None = None
    ) -> tuple[torch.Tensor, dict]:
        """Forward pass with support for variable input lengths up to input_len."""
        b, l, d_in = inputs.shape
        assert (
            d_in == self.d_in
        ), f"Input dimension {d_in} does not match expected dimension {self.d_in}"
        assert (
            l <= self.input_len
        ), f"Input sequence length {l} exceeds maximum length {self.input_len}"

        # Set initial state if provided, otherwise reset
        if initial_state is not None:
            self.set_state(initial_state)
        else:
            self.reset_state(batch_size=b)

        # Pad input if necessary
        if l < self.input_len:
            pad_length = self.input_len - l
            inputs_padded = torch.nn.functional.pad(inputs, (0, 0, 0, pad_length))
        else:
            inputs_padded = inputs

        x_tilde = self.compute_x_tilde(inputs_padded)
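        # Spectral component: project features (b, l, k * d_in) down to d_out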
        delta_phi = torch.einsum("blk,ko->blo", x_tilde, self.m_phi)
        delta_ar_u = self.compute_ar_x_preds(inputs_padded)
        # Run the y recurrence only over the real (non-padded) steps so the
        # final state reflects the actual sequence
        output = self.compute_y_t(self.init_m_y, (delta_phi + delta_ar_u)[:, :l])

        # Update state['x'] with the last auto_reg_k_u inputs (both slices clamp to min(l, auto_reg_k_u))
        self.state["x"] = torch.roll(self.state["x"], shifts=-l, dims=1)
        self.state["x"][:, -l:] = inputs[:, -self.auto_reg_k_u :]

        # Return only the non-padded output and the final state
        return output[:, :l, :], self.get_state()

    def step(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Auto-regressive step function using current state."""
        b, d_in = x.shape
        assert (
            d_in == self.d_in
        ), f"Input dimension {d_in} does not match expected dimension {self.d_in}"
        assert b == self.state["x"].shape[0], "Batch size mismatch with current state"

        # Update state['x']
        self.state["x"] = torch.roll(self.state["x"], shifts=-1, dims=1)
        self.state["x"][:, -1] = x

        # Compute x_tilde for the current step
        x_expanded = x.unsqueeze(1)  # Shape: (b, 1, d_in)
        x_tilde = self.compute_x_tilde(x_expanded)[:, 0]  # Shape: (b, k * d_in)

        # Compute delta_phi
        delta_phi = torch.matmul(x_tilde, self.m_phi)

        # Compute delta_ar_u
        delta_ar_u = torch.einsum("oik,bki->bo", self.m_u, self.state["x"])

        # Compute y_t
        y_t = (
            torch.einsum("oky,bky->bo", self.init_m_y, self.state["y"])
            + delta_phi
            + delta_ar_u
        )

        # Update state['y']
        self.state["y"] = torch.roll(self.state["y"], shifts=1, dims=1)
        self.state["y"][:, 0] = y_t

        return y_t, self.get_state()
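
As a minimal usage sketch (shapes and hyperparameters here are illustrative), the layer can be primed on a prompt with forward and then driven one input at a time with step:

stu = STU(d_in=8, d_out=8, input_len=64, num_eigh=4)

# Prime on a prompt sequence and capture the final state
prompt = torch.randn(2, 16, 8)     # (batch, length, d_in)
outputs, state = stu(prompt)       # outputs: (2, 16, 8)

# Continue auto-regressively, one input at a time, from that state
stu.set_state(state)
next_input = torch.randn(2, 8)     # (batch, d_in)
y_t, state = stu.step(next_input)  # y_t: (2, 8)
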
windsornguyen commented 1 month ago

Hi @kashif - thank you for your port!

Flash STU omits the autoregressive component from the original STU paper and relies solely on the spectral component. We found that the autoregressive part sometimes helped performance, but it was a bit too slow for our liking.
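
For intuition, here is a minimal sketch (not the actual flash_stu code) of what a purely spectral forward pass looks like: convolve the inputs with eigenvalue-scaled spectral filters via FFT and project the result:

import torch

def spectral_forward(u, phi, m_phi):
    """Sketch of a spectral-only STU pass (illustrative, not the flash_stu API).

    u:     (batch, seq_len, d_in) inputs
    phi:   (seq_len, k) spectral filters, pre-scaled by the eigenvalues
    m_phi: (k * d_in, d_out) output projection
    """
    b, l, d_in = u.shape
    k = phi.shape[1]

    # Causal convolution of every filter with every input channel via FFT,
    # padded to 2*l to avoid circular wrap-around
    n = 2 * l
    U = torch.fft.rfft(u, n=n, dim=1)    # (b, n//2+1, d_in)
    P = torch.fft.rfft(phi, n=n, dim=0)  # (n//2+1, k)
    conv = torch.fft.irfft(U.unsqueeze(2) * P.unsqueeze(0).unsqueeze(-1), n=n, dim=1)
    conv = conv[:, :l]                   # (b, l, k, d_in)

    # Project the spectral features down to the output dimension
    return conv.reshape(b, l, k * d_in) @ m_phi  # (b, l, d_out)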

Here is an example of what an inference script for Flash STU could look like:

import tiktoken
import torch
from flash_stu import FlashSTU, FlashSTUConfig, get_spectral_filters
from safetensors import safe_open

tokenizer = tiktoken.get_encoding('o200k_base')
prompt = "Hi, my name is"
device = torch.device('cuda')

def generate_text(model, tokenizer, prompt, num_return_sequences=5, max_length=1024, device='cuda', temperature=1.0, top_k=50):
    model.eval()
    tokens = torch.tensor([tokenizer.encode(prompt, allowed_special={'<|endoftext|>'})], device=device)
    tokens = tokens.repeat(num_return_sequences, 1)

    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(1337)

    eos_token_id = tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]

    with torch.no_grad():
        for _ in range(max_length - tokens.size(1)):
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = model(tokens)
                logits = logits[:, -1, :]  # Get logits for the last token

                # Apply temperature scaling if temperature > 0
                if temperature > 0:
                    logits = logits / temperature

            probs = torch.nn.functional.softmax(logits, dim=-1)  # Compute probabilities

            # Top-K sampling: set all probabilities outside the top K to 0
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
            ix = torch.multinomial(top_k_probs, 1, generator=sample_rng)
            next_token = torch.gather(top_k_indices, -1, ix)
            tokens = torch.cat((tokens, next_token), dim=1) # The autoregressive part!

            # Break if EOS token is generated
            if (next_token == eos_token_id).any():
                break

    generated_sequences = []
    for i in range(num_return_sequences):
        decoded = tokenizer.decode(tokens[i].tolist())
        generated_sequences.append(decoded)

    return generated_sequences

config = FlashSTUConfig()

# Placeholder filter settings; these must match the configuration the
# checkpoint was trained with.
seq_len = 1024
num_eigh = 24
use_hankel_L = False

phi = get_spectral_filters(
    seq_len,
    num_eigh,
    use_hankel_L,
    device,
    torch.bfloat16,
)
model = FlashSTU(config, phi)

state_dict = {}
with safe_open('model.safetensors', framework="pt", device='cuda') as f:
    for k in f.keys():
        state_dict[k] = f.get_tensor(k)

model.load_state_dict(state_dict)
model.to(device)

generated_texts = generate_text(model, tokenizer, prompt, num_return_sequences=5, max_length=1024)
for i, text in enumerate(generated_texts):
    print(f"Sample {i + 1}: {text}\n")

This will generate until (1) the EOS token is produced or (2) the maximum sequence length is reached. Is this what you were looking for in terms of inference?

kashif commented 1 month ago

Thanks! Checking your way!

kashif commented 1 month ago

@windsornguyen as a slight aside, what are the key differences between the original STU implementation and the one in flash-STU, in just the STU module?

windsornguyen commented 1 month ago

@kashif

  1. The autoregressive component is omitted
  2. The spectral filters are scaled by the eigenvalues up front, instead of after the convolution
  3. The tensordot approximation, per the paper (2409.10489); see the sketch below
  4. Flash FFT is used for convolutions by default
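
Roughly, points 2 through 4 amount to something like the sketch below (assuming d_in == d_out; the names M_inputs and M_filters and the plain torch.fft convolution are illustrative, not the exact flash_stu implementation):

import torch

def stu_tensordot_sketch(u, phi, M_inputs, M_filters):
    """Rough sketch of the tensordot-approximated spectral pass.

    u:         (batch, seq_len, d) inputs, with d_in == d_out == d
    phi:       (seq_len, K) spectral filters, pre-scaled by eigenvalues ** 0.25
    M_inputs:  (d, d) input projection
    M_filters: (K, d) filter projection
    """
    b, l, d = u.shape

    # Tensordot approximation: replace the full (K, d, d) projection tensor
    # with two small matrices, projecting the inputs and collapsing the K
    # filters into a single filter per channel
    u_proj = u @ M_inputs        # (b, l, d)
    phi_proj = phi @ M_filters   # (l, d)

    # One channelwise causal convolution via FFT (Flash FFT in flash-stu),
    # padded to 2*l to avoid circular wrap-around
    n = 2 * l
    U = torch.fft.rfft(u_proj, n=n, dim=1)
    P = torch.fft.rfft(phi_proj, n=n, dim=0)
    return torch.fft.irfft(U * P.unsqueeze(0), n=n, dim=1)[:, :l]  # (b, l, d)
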
kashif commented 1 month ago

thanks!