berlino / gated_linear_attention


Worse performance with subchunking #2

Closed faresobeid closed 10 months ago

faresobeid commented 10 months ago

I have been experimenting with GLA, implementing it in pure PyTorch or JAX (no Triton), and have run into performance problems with subchunking. When using the chunkwise method on its own, I get promising results in terms of both speed and quality.

However, to make it work I have to add an epsilon, i.e. k / (A + 1e-12), or else k explodes very quickly, which is expected.
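Concretely, the epsilon-stabilized intra-chunk term looks like this (the same line appears in the full chunkwise code later in this thread; A_3_t is the cumulative log decay within the chunk):

o = q_ @ s + jnp.tril(q_ @ (k_t / (jnp.exp(A_3_t)+1e-12)).T) @ v_t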

To avoid the epsilon, I tried subchunking (almost the same as the pseudocode). Speed is not a big problem, but the model's performance becomes significantly worse.

The subchunked version ends up performing much worse than both the chunkwise and recurrent forms.

Is there any reason for this, and how could it be fixed?

PS: The main purpose of this is to hopefully use the GLA methods for a CUDA/Triton-free RWKV v6.

Code for subchunking in JAX:

def parallel_subchunk(self,Q,K,V,A):
        C = 16  # subchunk size
        L,D = Q.shape
        # split the first-level chunk of length L into L//C subchunks
        Q,K,V,A = map(lambda x: x.reshape((L//C,C,-1)),[Q,K,V,A])
        out = jnp.zeros((L//C,C,D))
        for i in range(L//C):  # loop over query subchunks
            o = jnp.zeros((C,D))
            q = Q[i]
            a_q = A[i]
            a_normalizer = a_q[0]  # cumulative log decay at the start of subchunk i
            q = q * jnp.exp(a_q-a_normalizer[None])
            for j in range(i):  # inter-subchunk contributions from earlier subchunks
                k = K[j]
                v = V[j]
                a_kv = A[j]
                k = k * jnp.exp(a_normalizer[None]-a_kv)
                o += (q @ k.T) @ v
            # intra-subchunk (causal) contribution
            k = K[i]
            v = V[i]
            k = k * jnp.exp(a_normalizer[None]-a_q)
            o += jnp.tril(q @ k.T) @ v
            out.at[i].set(o)
        return out.reshape((L,D))
sustcsonglin commented 10 months ago

Thanks for your interest! How much worse is the performance? Is L the first-level chunk size or the entire sequence length? Have you done any sanity checks, like comparing the outputs and gradients with the chunkwise or recurrent implementation? From the snippet itself I don't see problems and believe it is mathematically equivalent to the other two forms. My only concern about the code snippet is that

k = K[i]
v = V[i]
k = k * jnp.exp(a_normalizer[None]-a_q)
o += jnp.tril(q @ k.T) @ v

could potentially explode for the same reason you described.
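As a rough illustration of the concern (with a made-up per-step log decay, not values from the actual model):

import jax.numpy as jnp

C = 16                               # subchunk size
w = -6.0                             # hypothetical, unregularized per-step log decay
a_q = jnp.cumsum(jnp.full((C,), w))  # cumulative log decay within the subchunk
print(jnp.exp(a_q[0] - a_q[-1]))     # exp(90) -> inf in fp32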

faresobeid commented 10 months ago

Thank you for the quick reply! I've checked, and the outputs are equivalent between all implementations (chunkwise, recurrent, ...), although I haven't checked the gradients. I'm using the code above as part of my chunking, so L is the chunk size and C is the subchunk size. I haven't tried the log-semiring implementation for that, so I'll try it and respond very soon if that's fine.
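For the gradient check, something like this is what I would try (a rough sketch with placeholder shapes and an arbitrary scalar loss, not the actual model code):

import jax
import jax.numpy as jnp

def gradient_check(fn_a, fn_b, L=128, D=64, seed=0):
    # fn_a / fn_b: two supposedly equivalent forms taking (q, k, v, a) -> (L, D),
    # e.g. the subchunked and the recurrent implementation.
    q, k, v = jax.random.normal(jax.random.PRNGKey(seed), (3, L, D))
    w = -jnp.abs(jax.random.normal(jax.random.PRNGKey(seed + 1), (L, D)))
    a = jnp.cumsum(w, axis=0)  # cumulative log decay

    def loss(fn):
        return lambda q, k, v, a: jnp.sum(fn(q, k, v, a) ** 2)

    grads_a = jax.grad(loss(fn_a), argnums=(0, 1, 2, 3))(q, k, v, a)
    grads_b = jax.grad(loss(fn_b), argnums=(0, 1, 2, 3))(q, k, v, a)
    for name, ga, gb in zip("qkva", grads_a, grads_b):
        print(name, jnp.max(jnp.abs(ga - gb)))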

sustcsonglin commented 10 months ago

Ah, this seems weird. If the forward pass gives the same result and autograd works as normal, then I guess the gradients should be the same. What is your numerical precision, half or single? If it is half, you can use single or even double precision for the intra-sub-chunk computation as a workaround to avoid the log-semiring implementation, since the subchunk size itself is very small.

faresobeid commented 10 months ago

I'm using TPUs, so from my understanding the matmuls will be done in bfloat16. I can try single or double precision, but would you expect the performance loss to come from this? The performance difference is around 0.4 (cross-entropy loss) with vocab size = 98. This is from early in training, but it seems like the gap stays throughout training.

sustcsonglin commented 10 months ago

Actually, I use fp32 in my Triton-based implementation.

Yes, this will negatively affect the running speed because tensor cores cannot do fp32 matmuls (at least on Nvidia GPUs). But my sense is that the cost is relatively small because the subchunk size is as small as 16. And this is exactly my motivation for using subchunking: we can use bf16 matmuls between different subchunks for most of the computation.
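In JAX, one way to express this split (a sketch; the function name and shapes are illustrative, and only the small intra-sub-chunk product is upcast while the inter-subchunk matmuls keep the default bf16 precision on TPU) is:

import jax
import jax.numpy as jnp

def intra_subchunk(q, k, v):
    # q, k, v: (C, D) subchunk tensors; do the small causal product with fp32 accumulation
    scores = jnp.matmul(q, k.T, precision=jax.lax.Precision.HIGHEST)
    return jnp.tril(scores) @ v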

faresobeid commented 10 months ago

I just tried float32 and saw no real improvement, although it should be more numerically stable.

faresobeid commented 10 months ago

I can provide my full code if you want. It is different from the full architecture in the paper (I'm using the RWKV architecture), but that shouldn't make a difference.

sustcsonglin commented 10 months ago

Sure, I can help inspect the full code.

faresobeid commented 10 months ago
# -*- coding: utf-8 -*-
"""rwkv-v6-best (17).ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1APJXUus7xotOQbITxoJsRPXjYKosFg8Q
"""

!pip install --upgrade -q pip
!pip install -q --upgrade jax[tpu] jaxlib
!pip install -q equinox
!pip install -q --upgrade matplotlib
import jax
import equinox as eqx
import equinox.nn as nn
import math
from jax import lax, random, numpy as jnp
import jax.nn as jnn
import optax
from itertools import repeat
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import time
import functools as ft
from typing import Callable, Any, Optional,Tuple, List
from jaxtyping import Array, Float
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as sharding
from jax import custom_vjp
from jax.experimental import pallas as pl

data = np.memmap('/kaggle/input/tinystories-newest/tinystories.bin', dtype=np.uint8, mode='r')
n = int(len(data)*0.95)
train_data = data[:n]
val_data = data[n:]
chars = ['<EOS>', '\t', '\n', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']
vocab_size = len(chars)
ctoi = {c: i for i, c in enumerate(chars)}
itoc = {i: c for i, c in enumerate(chars)}
def encode(x):
    x = x.split(' ')
    tokens = []
    for i,c in enumerate(x):
        if c == '<EOS>':
            tokens.append(ctoi[c])

        else:
            for ch in c:
                tokens.append(ctoi[ch])
            if i != len(x)-1:
                tokens.append(ctoi[' '])
    return {'input_ids': bytes(tokens), 'len': len(tokens)}
decode = lambda x: ''.join([itoc[i] for i in x])

def get_batch(batch_size,block_size,split):
    data = train_data if split == 'train' else val_data
    ix = np.random.randint(len(data) - block_size, size=(batch_size,))
    x = np.stack([data[i: i + block_size] for i in ix])
    y = np.stack([data[i + 1: i + block_size + 1] for i in ix])
    return x,y

class Linear(eqx.Module):
    weight: Array
    def __init__(self,in_features,out_features,key,scale=1.):
        if out_features > in_features:
            gain = math.sqrt(out_features / in_features)
        else:
            gain = 1.
        self.weight = jnn.initializers.orthogonal(gain*scale)(key,(out_features, in_features))
    def __call__(self, x):
        x = self.weight @ x
        return x

class RMSNorm(eqx.Module):
    dim: int = eqx.static_field()
    use_weight: bool = eqx.static_field()
    weight: Optional[Array]
    def __init__(self, dim, use_weight=True):
        self.dim = dim
        self.use_weight = use_weight
        self.weight = jnp.ones(dim) if use_weight else None
    def __call__(self,x):
        inv_rms = jax.lax.rsqrt(jnp.sum(x**2) + 1e-5)
        out = jnp.sqrt(self.dim) * inv_rms * x
        if self.use_weight:
            out = self.weight * out
        return out

def time_shift(x,sx):
    return jnp.concatenate((sx[None], x[:-1, :])),x[-1,:]

class ChannelMix(eqx.Module):
    ln: RMSNorm
    Wk: Linear
    Wr: Linear
    Wv: Linear
    mu_k: Array
    mu_r: Array
    def __init__(self,L_id,L,D,key):
        keys = random.split(key,3)
        self.ln = RMSNorm(D)
        self.Wk = Linear(D,int(3.5*D),keys[0])
        self.Wr = Linear(D,D,keys[1])
        self.Wv = Linear(int(3.5*D),D,keys[2],scale=0)
        ratio_0_to_1 = L_id / (L - 1)
        ratio_1_to_almost0 = 1.0 - (L_id / L)
        ddd = (jnp.arange(D)/D)[None]
        self.mu_k = jnp.power(ddd, ratio_1_to_almost0)
        self.mu_r = jnp.power(ddd, ratio_1_to_almost0)
    def __call__(self,x,state):
        xx = jax.vmap(self.ln)(x)
        sx,state = time_shift(xx,state)
        sx = sx - xx
        kx = xx + sx * self.mu_k
        rx = xx + sx * self.mu_r
        k = jax.vmap(self.Wk)(kx)
        r = jax.vmap(self.Wr)(rx)
        k = jnp.square(jnn.relu(k))
        out = jnn.sigmoid(r) * jax.vmap(self.Wv)(k)
        return out + x,state

class TimeMix(eqx.Module):
    ln: RMSNorm
    gn: nn.GroupNorm
    x_maa: Array
    r_maa: Array
    w_maa: Array
    k_maa: Array
    v_maa: Array
    g_maa: Array
    tm_w1: Array
    tm_w2: Array
    td_w1: Array
    td_w2: Array
    w: Array
#     u: Array
    Wk: Linear
    Wv: Linear
    Wr: Linear
    Wg: Linear
    Wo: Linear
    H: int = eqx.static_field()
    C: int = eqx.static_field()

    def __init__(self,L_id,L,H,D,key):
        keys = random.split(key,7)
        self.ln = RMSNorm(D)
        self.gn = nn.GroupNorm(H,D)
        ratio_0_to_1 = L_id / (L - 1)
        ratio_1_to_almost0 = 1.0 - (L_id / L)
        ddd = (jnp.arange(D)/D)[None]
        self.x_maa = jnp.power(ddd, ratio_1_to_almost0)
        self.r_maa = jnp.power(ddd, ratio_1_to_almost0)
        self.w_maa = jnp.power(ddd, ratio_1_to_almost0)
        self.k_maa = jnp.power(ddd, ratio_1_to_almost0)
        self.v_maa = jnp.power(ddd, ratio_1_to_almost0)
        self.g_maa = jnp.power(ddd, ratio_1_to_almost0)
        TIME_MIX_EXTRA_DIM = 32
        self.tm_w1 = random.uniform(keys[0],(D, TIME_MIX_EXTRA_DIM * 5),minval=-0.01,maxval=0.01)
        self.tm_w2 = jnp.zeros((5, TIME_MIX_EXTRA_DIM, D))
        W_MIX_EXTRA_DIM = 64
        self.td_w1 = random.uniform(keys[1],(D, W_MIX_EXTRA_DIM),minval=-0.01,maxval=0.01)
        self.td_w2 = jnp.zeros((W_MIX_EXTRA_DIM, D))
        self.w = (-6 + 5 * (jnp.arange(D) /(D - 1)) ** (0.7 + 1.3 * ratio_0_to_1)).reshape((1,H,-1))
#         n = jnp.arange(D)
#         zigzag = ((n + 1) % 3 - 1) * 0.1
#         self.u = (ratio_0_to_1 * (1 - (n / max(D - 1, 1))) + zigzag).reshape((H, -1, 1))
        self.Wk = Linear(D,D,keys[2])
        self.Wv = Linear(D,D,keys[3])
        self.Wr = Linear(D,D,keys[4])
        self.Wg = Linear(D,D,keys[5])
        self.Wo = Linear(D,D,keys[6],0)
        self.H = H
        self.C = 128

    def rwkv(self,k,v,r,w,s):
        w = jnp.exp(-jnp.exp(w))
        k,v,r = map(lambda x: x[:,None,:],[k,v,r])
        w = w[:,:,None]
        def loop(s,W_t):
            k_t,v_t,r_t,w_t = W_t
            s = w_t * s + k_t.T @ v_t
            out_t = r_t @ s
            return s,out_t.squeeze(0)
        s,out = jax.lax.scan(loop,s,(k,v,r,w))
        return s,out
#     def parallel_log(self,Q,K,A):

    def parallel_subchunk(self,Q,K,V,A):
        C = 16
        L,D = Q.shape

        Q,K,V,A = map(lambda x: x.reshape((L//C,C,-1)),[Q,K,V,A])

        out = jnp.zeros((L//C,C,D))
        for i in range(L//C):
            o = jnp.zeros((C,D))
            q = Q[i]
            a_q = A[i]
            a_normalizer = a_q[0]
            q = q * jnp.exp(a_q-a_normalizer[None])
            for j in range(i):
                k = K[j]
                v = V[j]
                a_kv = A[j]
                k = k * jnp.exp(a_normalizer[None]-a_kv)
                o += (q @ k.T) @ v
            k = K[i]
            v = V[i]
            k = k * jnp.exp(a_normalizer[None]-a_q)
            o += jnp.tril(jnp.matmul(q, k.T, precision=jax.lax.Precision('float32'))) @ v
            out.at[i].set(o)
        return out.reshape((L,D))

    def chunkwise(self,q,k,v,w):
        w = -jnp.exp(w)
        C = self.C
        L,D_q = q.shape
        D_v = v.shape[-1]
        q,k,v,w = map(lambda x: x.reshape((L//C,C,-1)),[q,k,v,w])
        A_1 = jnp.sum(w,1)[:,None]
        A_2 = jax.lax.cumsum(w,axis=1,reverse=True)-w
        A_3 = jnp.cumsum(w,axis=1)
        s = jnp.zeros((D_q,D_v))

        def loop(s,W_t):
            q_t,k_t,v_t,A_1_t,A_2_t,A_3_t,w_t = W_t
            s_new = jnp.exp(A_1_t).T * s + (jnp.exp(A_2_t) * k_t).T @ v_t
            q_ = q_t * jnp.exp(A_3_t)
#             o = q_ @ s + self.parallel_subchunk(q_t,k_t,v_t,A_3_t)
            o = q_ @ s + jnp.tril(q_ @ (k_t / (jnp.exp(A_3_t)+1e-12)).T) @ v_t
            return s_new,o
        _,out = jax.lax.scan(loop,s,(q,k,v,A_1,A_2,A_3,w))
        return out.reshape((L,D_v))

    def __call__(self,x,x_state,s,output_state):
        T,D = orig_shape = x.shape
        xx = jax.vmap(self.ln)(x)
        sx,x_state = time_shift(xx,x_state)
        sx = sx - xx
        xxx = xx + sx * self.x_maa
        xxx = jnn.tanh(xxx @ self.tm_w1).reshape((T, 5, -1)).swapaxes(0, 1)
        mw, mk, mv, mr, mg = (xxx @ self.tm_w2).reshape((5, T, -1))
        wx = xx + sx * (self.w_maa + mw)
        kx = xx + sx * (self.k_maa + mk)
        vx = xx + sx * (self.v_maa + mv)
        rx = xx + sx * (self.r_maa + mr)
        gx = xx + sx * (self.g_maa + mg)
        w = self.w + (jnn.tanh(wx @ self.td_w1) @ self.td_w2).reshape((T,self.H,-1))
        k = jax.vmap(self.Wk)(kx)
        v = jax.vmap(self.Wv)(vx)
        r = jax.vmap(self.Wr)(rx)
        g = jax.vmap(self.Wg)(gx)
        k,v,r,w = map(lambda x: x.reshape((T,self.H,-1)).swapaxes(0,1),[k,v,r,w])
        g = jnn.silu(g)
        if output_state:
            s,out = jax.vmap(self.rwkv)(k,v,r,w,s)
        else:
            s = None
            out = jax.vmap(self.chunkwise)(r,k,v,w)
        out = out.swapaxes(0,1).reshape(orig_shape)
        out = jax.vmap(self.gn)(out)
        out = out.astype(x.dtype) * g
        out = jax.vmap(self.Wo)(out)
        return x + out,x_state,s

class Block(eqx.Module):
    time_mix: TimeMix
    channel_mix: ChannelMix
    empty_x_state: Array = eqx.static_field()
    empty_attn_state: Array = eqx.static_field()
    empty_ffn_state: Array = eqx.static_field()
    def __init__(self,L_id,L,H,D,key):
        keys = random.split(key,2)
        self.time_mix = TimeMix(L_id,L,H,D,keys[0])
        self.channel_mix = ChannelMix(L_id,L,D,keys[1])
        self.empty_x_state = jnp.zeros(D)
        self.empty_attn_state = jnp.zeros((H,D//H,D//H))
        self.empty_ffn_state = jnp.zeros(D)
    def __call__(self, x, block_state, output_state):
        if block_state is None:
            block_state = (self.empty_x_state,self.empty_attn_state,self.empty_ffn_state)
        x_state, attn_state, ffn_state = block_state
        x, x_state, attn_state = self.time_mix(x, x_state, attn_state, output_state)
        x, ffn_state = self.channel_mix(x, ffn_state)
        return x,(x_state, attn_state, ffn_state)

class RWKV(eqx.Module):
    emb: Array
    ln_b: RMSNorm
    blocks: List[Block]
    ln_f: RMSNorm
    final: Linear
    def __init__(self,L,H,D,V,key):
        keys = random.split(key,3)
        keys_blocks = random.split(keys[0],L)
        self.emb = jnn.initializers.orthogonal(1e-4)(keys[1],(V,D))
        self.ln_b = RMSNorm(D)
        self.blocks = [Block(i,L,H,D,key) for i,key in enumerate(keys_blocks)]
        self.ln_f = RMSNorm(D)
        self.final = Linear(D,V,keys[2],0.5)
    def __call__(self,x, block_states=None, output_state=None):

        x = jax.vmap(lambda x: self.emb[x])(x)
        x = jax.vmap(self.ln_b)(x)
        next_states = []
        if block_states is None:
            block_states = repeat(None)
        for block, state in zip(self.blocks, block_states):
            x,new_state = block(x, state, output_state)
            if output_state:
                next_states.append(new_state)
        x = jax.vmap(self.ln_f)(x)
        x = jax.vmap(self.final)(x)
        if output_state:
            return x, next_states
        return x
    def generate(self,input_tokens,max_len,key):
        input_tokens = input_tokens.astype(jnp.int32)
        keys = random.split(key,max_len)
        _,new_states = self.__call__(input_tokens[:-1].reshape((-1,)),None,True)
        next_token = input_tokens[-1]
        all_tokens = jnp.empty((max_len),dtype=jnp.int32)
        def loop(i,carry):
            next_token,new_states,all_tokens = carry
            logits,new_states = self.__call__(next_token.reshape((-1,)),new_states,True)
            next_token_logits = logits[-1]/0.9
            top_logits, top_tokens = jax.lax.top_k(next_token_logits, min(40, len(next_token_logits)))
            token_idx = jax.random.categorical(keys[i], top_logits)
            next_token = top_tokens[token_idx]
            all_tokens = all_tokens.at[i].set(next_token)
            return next_token,new_states,all_tokens
        _,_,all_tokens = jax.lax.fori_loop(0,max_len,loop,(next_token,new_states,all_tokens))
        print(decode(all_tokens.tolist()))

@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    loss = optax.softmax_cross_entropy_with_integer_labels(pred_y, y).mean()
    return loss

@eqx.filter_jit(donate="all")
def make_step(model, x, y, opt_state, opt_update):
    loss,grads = loss_fn(model, x, y)
    updates, opt_state = opt_update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

def get_shard():
    num_devices = len(jax.devices())
    devices = mesh_utils.create_device_mesh((num_devices, 1))
    shard = sharding.PositionalSharding(devices)
    return shard
def print_avg_loss(losses,n):
    loss = losses[-n:]
    sum_ = sum(loss)
    mean = sum_/len(loss)
    print('Avg of last ' + str(n) + ' losses: ' + str(mean))
def print_params(model):
    param_count = sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array)))
    print('Number of Parameters: ' + str(param_count/1e6) + ' (M)')

def train(seed):
    N = 10000
    B = 128
    L = 4
    H = 4
    D = 128
    V = vocab_size
    S = 1024
    sample_len = 1024
    LR = 0.01
    key = random.PRNGKey(seed)
    keys = random.split(key,3)
    key_sampling = keys[0]
    model = RWKV(L,H,D,V,keys[1])
    print_params(model)
    schedule = optax.linear_onecycle_schedule(N,LR)
    opt = optax.chain(optax.clip_by_global_norm(1.),optax.adamw(schedule,b1=0.9,b2=0.98,weight_decay=0.))
    opt_state = opt.init(eqx.filter(model, eqx.is_array))
    shard = get_shard()
    losses = []
    pbar = tqdm(range(N),desc='training:', position=0, leave=True)
    for i in pbar:
        x,y = get_batch(B,S,'train')
        x,y = jax.device_put((x,y),shard)
        model, opt_state, loss = make_step(model, x, y, opt_state, opt.update)
        if jnp.isnan(loss):
            print(i,'nans!!!!!!')
            break
        pbar.set_postfix(train_loss=loss.item())
        losses.append(np.asarray(loss))
        if i % (N//10) == 0 or (i == N - 1):
            print_avg_loss(losses,10)
            x,_ = get_batch(B,S,'val')
            prime = x[0][:50]
            prime_str = decode(prime.tolist())
            print(prime_str, "\n", "*" * 40)
            key_sample = random.fold_in(key_sampling,i)
            start = time.time()
            model.generate(prime,sample_len,key_sample)
            end = time.time()
            print('Time Taken: ' + str(end-start))
    plt.plot(np.log(losses))
    plt.show()
    return losses
faresobeid commented 10 months ago

Sorry for the messy code!

sustcsonglin commented 10 months ago

(Though not related to the subchunking)

Shouldn't s_new = jnp.exp(A_1_t).T * s + (jnp.exp(A_2_t) * k_t).T @ v_t be s_new = jnp.exp(A_1_t) * s + (jnp.exp(A_2_t) * k_t).T @ v_t ?

faresobeid commented 10 months ago

Oh OK, I just thought that in the paper alpha was 1 x d, so with beta = 1, G would be alpha.T, which is d x 1. Thank you for pointing it out though!

sustcsonglin commented 10 months ago

In A_1 = jnp.sum(w,1)[:,None], won't the None make the last dimension size 1, so that if you do the transpose you end up with the gating in the V dimension instead of the K dimension?

faresobeid commented 10 months ago

w has shape L//C x C x D, so A_1 = jnp.sum(w,1)[:,None] makes A_1 shape L//C x 1 x D, and A_1_t has shape 1 x D. So, if I'm not mistaken, the transpose puts the gating in the K dimension, which fits with the paper. If you replace D with D_k, then you have to transpose to multiply by the state of shape D_k x D_v.

sustcsonglin commented 10 months ago

Sorry, I'm not familiar with Jax at all. My guess is that jnp.sum(w,1) will give L//C x D, and then [:, None] will make the resulting size L//C x D x 1. Why is it L//C x 1 x D? Is this specific to Jax?

faresobeid commented 10 months ago

From my understanding, if w has shape L//C x D, then w[:,None] is the same as w[:,None,:], which is L//C x 1 x D.
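A quick shape check (plain NumPy-style indexing, nothing JAX-specific):

import jax.numpy as jnp

w = jnp.zeros((8, 128))       # L//C x D
print(w[:, None].shape)       # (8, 1, 128): None inserts a new axis at position 1
print(w[:, None, :].shape)    # (8, 1, 128): identical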

sustcsonglin commented 10 months ago

Ah, you are right. I just misread it as `[..., None]`.

I have read through the code and it looks good to me. I still have no idea why subchunking fails; maybe some gradient checking is needed. Given that I am not familiar with Jax, I cannot help debug further.

faresobeid commented 10 months ago

That's fine. I can try to rewrite this in PyTorch to see if there is any difference, and maybe then you could help debug?

sustcsonglin commented 10 months ago

Sure!

faresobeid commented 10 months ago

Would it be better to remove all the RWKV-specific stuff for debugging? It would also be easier to code.

sustcsonglin commented 10 months ago

Yep! I believe so. Here is my previous debugging file for the Triton implementation: link. Everything there looks good.

faresobeid commented 10 months ago

As an update, I realised that the line out.at[i].set(o) should be out = out.at[i].set(o), so out stayed all zeros; it was a silly mistake. Now the performance is much better, but it is very unstable even in float32. I will try using the log semiring to fix that and update you on how it goes. If possible, do you know an optimal implementation for that? Either way, thank you for helping, and sorry for wasting your time on such a mistake.
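For completeness, the fix is just to rebind the functional update, since JAX arrays are immutable and .at[...].set returns a new array rather than modifying in place:

out.at[i].set(o)         # wrong: the updated array is discarded, out stays zero
out = out.at[i].set(o)   # right: rebind the returned, updated array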

sustcsonglin commented 10 months ago

Ah, good to hear! Fp32 without regularization on the decay rate is prone to numerical issues! My suggestions for the next steps:

faresobeid commented 10 months ago

Thank you for the reply!

  1. I'll try this as soon as possible.
  2. I can try this, but fp64 is not recommended on JAX and TPUs.
  3. I tried this, but it was unfortunately too slow.
  4. Thank you for the CUDA code, I'll try that as well.

Just as a side note, subchunking is much more stable than chunkwise (without the epsilon), which is a good sign.
sustcsonglin commented 10 months ago

Yes, this is part of the motivation for subchunking! If you use a small first-level chunk size, you get better numerical stability, but you need to materialize more states into HBM, which is slow. If you use a large first-level chunk size, the numerical issues become a nightmare. Subchunking avoids both issues simultaneously. In my experience, fp32 gives good numerical stability for subchunk size 16 with a reasonable regularization, e.g., setting the smallest log decay rate to -3. Note that e^-3 ≈ 0.0498, which is reasonably small, and the maximum value is then e^{3*16} = e^{48}, within the range of fp32. I don't think this is a strong regularization, because the model can still forget a lot of information.
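Concretely, the kind of regularization described above could look roughly like this (a sketch: the function name is illustrative, and the -exp parametrization is taken from the code posted earlier, not necessarily from this repo's implementation):

import jax.numpy as jnp

def clamped_log_decay(w_raw, min_log_decay=-3.0):
    # per-step log decay is negative; clamping it at -3 keeps the per-step
    # decay rate above e^-3 ~= 0.0498
    w = -jnp.exp(w_raw)
    return jnp.maximum(w, min_log_decay)

# With subchunk size 16, the largest factor materialized inside a subchunk is then
# at most exp(3 * 16) = e^48 ~= 7e20, comfortably within fp32 range (~3.4e38).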

SmerkyG commented 10 months ago

Thanks! We've been discussing it and we're still a bit uncertain about one thing regarding the subchunked numerical stability you mention when using fp32 and clipping to 0.05+. Maybe you could help clarify:

In subchunking, aren't we at some point comparing the key from position zero of the chunk with a query from within the chunk's final subchunk? In that case the a_normalizer comes from the last subchunk, meaning the key could be up to 128-16=112 positions away, and 0.05^-112 > 1e145, which would be infinity in fp32.

sustcsonglin commented 10 months ago

While the key could be up to 112 positions away, the inter-sub-chunk computation is always numerically stable. The only tricky part is the intra-sub-chunk part, where the upper bound on the position distance is 16.

sustcsonglin commented 10 months ago
        for i in range(L//C):
            o = jnp.zeros((C,D))
            q = Q[i]
            a_q = A[i]
            a_normalizer = a_q[0]
            q = q * jnp.exp(a_q-a_normalizer[None])
            for j in range(i):
                k = K[j]
                v = V[j]
                a_kv = A[j]
                k = k * jnp.exp(a_normalizer[None]-a_kv)
                o += (q @ k.T) @ v

That is, the code above is always numerically stable because a_normalizer[None] - a_kv is always smaller than 0.

SmerkyG commented 10 months ago

Thank you for the explanation! We've been having good success (both speed and accuracy) with some variations on these themes in both PyTorch and JAX now.

sustcsonglin commented 10 months ago

Good to know! Hope our hardware-efficient method helps with your RWKV-v6 training :)