Thanks for your interest! How much worse is the performance? Is L the first-level chunk size or the entire sequence length? Have you done any sanity checks, like comparing the outputs and gradients against the chunkwise or recurrent implementation? From the snippet itself I did not see any problems and believe it is mathematically equivalent to the other two forms. My only concern about the code snippet is that
```python
k = K[i]
v = V[i]
k = k * jnp.exp(a_normalizer[None]-a_q)
o += jnp.tril(q @ k.T) @ v
```
could potentially explode for the same reason you described.
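To make that concern concrete, here is a toy illustration (my own, with a made-up decay value; not code from the thread) of how the rescaling in the snippet above can overflow when the decay is strong:

```python
import jax.numpy as jnp

# With an aggressive per-step log decay of -8 and a subchunk of 16 positions,
# the intra-subchunk rescaling factor grows up to exp(8 * 15) = e^120,
# which overflows fp32 (whose max is ~3.4e38, i.e. roughly e^88).
w = jnp.full((16,), -8.0)            # per-step log decay within one subchunk
a_q = jnp.cumsum(w)                  # cumulative log decay at each position
a_normalizer = a_q[0]
scale = jnp.exp(a_normalizer - a_q)  # grows like e^{8 * distance}
print(scale[-1])                     # inf in fp32
```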
Thank you for the quick reply! I've checked and the outputs are equivalent between all implementations (chunkwise, recurrent, ...), although I haven't checked the gradients. I'm using the code above as part of my chunking, so L would be the chunk size and C is the subchunk size. I haven't tried the log-semiring implementation for that, so I'll try it and respond very soon if that's fine.
Ah, this seems weird. If the forward pass gives the same result and autograd works as normal, then I guess the gradients should be the same. What is your numerical precision, half or single? If it is half, you can use single or even double precision for the intra-sub-chunk computation as a workaround to avoid the log-semiring implementation, since the subchunk size itself is very small.
I'm using TPUs, so from my understanding the computation will be done in bfloat16. I can try single or double precision, but would you expect the performance loss to come from this? The performance difference is around 0.4 (cross-entropy loss) with vocab size = 98. This is from early in training, but the gap seems to stay throughout training.
Actually, I use fp32 in my Triton-based implementation.
Yes, this will negatively influence the running speed because tensor cores cannot do fp32 matmuls (at least on NVIDIA GPUs). But my sense is that the cost is relatively small because the subchunk size is as small as 16. And this is exactly my motivation for subchunking: we can use bf16 matmuls between different subchunks for most of the computation.
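As a rough sketch of that mix in JAX (my own illustration, not the actual Triton kernel; the function name and shapes are assumptions): the tiny intra-subchunk score matrix is computed in fp32 while the surrounding matmuls can stay in bf16.

```python
import jax
import jax.numpy as jnp

def intra_subchunk_fp32(q, k, v):
    # q, k, v: (C, d) bf16 tensors for one subchunk (C ~ 16).
    # Compute the small C x C score matrix in fp32 for numerical stability ...
    scores = jnp.matmul(q.astype(jnp.float32), k.astype(jnp.float32).T,
                        precision=jax.lax.Precision.HIGHEST)
    # ... then apply the causal mask and do the value matmul; casting back to
    # bf16 here is just part of the sketch, and inter-subchunk matmuls can stay
    # in bf16 throughout.
    return jnp.tril(scores).astype(jnp.bfloat16) @ v
```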
I just tried float32 and saw no real improvement, although it should be more numerically stable.
I can provide my full code if you want. It is different from the full architecture in the paper, since I'm using the RWKV architecture, but that shouldn't make a difference.
Sure, I can help inspect the full code.
```python
# -*- coding: utf-8 -*-
"""rwkv-v6-best (17).ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1APJXUus7xotOQbITxoJsRPXjYKosFg8Q
"""
!pip install --upgrade -q pip
!pip install -q --upgrade jax[tpu] jaxlib
!pip install -q equinox
!pip install -q --upgrade matplotlib
import jax
import equinox as eqx
import equinox.nn as nn
import math
from jax import lax, random, numpy as jnp
import jax.nn as jnn
import optax
from itertools import repeat
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import time
import functools as ft
from typing import Callable, Any, Optional,Tuple, List
from jaxtyping import Array, Float
import jax.experimental.mesh_utils as mesh_utils
import jax.sharding as sharding
from jax import custom_vjp
from jax.experimental import pallas as pl
data = np.memmap('/kaggle/input/tinystories-newest/tinystories.bin', dtype=np.uint8, mode='r')
n = int(len(data)*0.95)
train_data = data[:n]
val_data = data[n:]
chars = ['<EOS>', '\t', '\n', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']
vocab_size = len(chars)
ctoi = {c: i for i, c in enumerate(chars)}
itoc = {i: c for i, c in enumerate(chars)}
def encode(x):
x = x.split(' ')
tokens = []
for i,c in enumerate(x):
if c == '<EOS>':
tokens.append(ctoi[c])
else:
for ch in c:
tokens.append(ctoi[ch])
if i != len(x)-1:
tokens.append(ctoi[' '])
return {'input_ids': bytes(tokens), 'len': len(tokens)}
decode = lambda x: ''.join([itoc[i] for i in x])
def get_batch(batch_size,block_size,split):
data = train_data if split == 'train' else val_data
ix = np.random.randint(len(data) - block_size, size=(batch_size,))
x = np.stack([data[i: i + block_size] for i in ix])
y = np.stack([data[i + 1: i + block_size + 1] for i in ix])
return x,y
class Linear(eqx.Module):
weight: Array
def __init__(self,in_features,out_features,key,scale=1.):
if out_features > in_features:
gain = math.sqrt(out_features / in_features)
else:
gain = 1.
self.weight = jnn.initializers.orthogonal(gain*scale)(key,(out_features, in_features))
def __call__(self, x):
x = self.weight @ x
return x
class RMSNorm(eqx.Module):
dim: int = eqx.static_field()
use_weight: bool = eqx.static_field()
weight: Optional[Array]
def __init__(self, dim, use_weight=True):
self.dim = dim
self.use_weight = use_weight
self.weight = jnp.ones(dim) if use_weight else None
def __call__(self,x):
inv_rms = jax.lax.rsqrt(jnp.sum(x**2) + 1e-5)
out = jnp.sqrt(self.dim) * inv_rms * x
if self.use_weight:
out = self.weight * out
return out
def time_shift(x,sx):
return jnp.concatenate((sx[None], x[:-1, :])),x[-1,:]
class ChannelMix(eqx.Module):
ln: RMSNorm
Wk: Linear
Wr: Linear
Wv: Linear
mu_k: Array
mu_r: Array
def __init__(self,L_id,L,D,key):
keys = random.split(key,3)
self.ln = RMSNorm(D)
self.Wk = Linear(D,int(3.5*D),keys[0])
self.Wr = Linear(D,D,keys[1])
self.Wv = Linear(int(3.5*D),D,keys[2],scale=0)
ratio_0_to_1 = L_id / (L - 1)
ratio_1_to_almost0 = 1.0 - (L_id / L)
ddd = (jnp.arange(D)/D)[None]
self.mu_k = jnp.power(ddd, ratio_1_to_almost0)
self.mu_r = jnp.power(ddd, ratio_1_to_almost0)
def __call__(self,x,state):
xx = jax.vmap(self.ln)(x)
sx,state = time_shift(xx,state)
sx = sx - xx
kx = xx + sx * self.mu_k
rx = xx + sx * self.mu_r
k = jax.vmap(self.Wk)(kx)
r = jax.vmap(self.Wr)(rx)
k = jnp.square(jnn.relu(k))
out = jnn.sigmoid(r) * jax.vmap(self.Wv)(k)
return out + x,state
class TimeMix(eqx.Module):
ln: RMSNorm
gn: nn.GroupNorm
x_maa: Array
r_maa: Array
w_maa: Array
k_maa: Array
v_maa: Array
g_maa: Array
tm_w1: Array
tm_w2: Array
td_w1: Array
td_w2: Array
w: Array
# u: Array
Wk: Linear
Wv: Linear
Wr: Linear
Wg: Linear
Wo: Linear
H: int = eqx.static_field()
C: int = eqx.static_field()
def __init__(self,L_id,L,H,D,key):
keys = random.split(key,7)
self.ln = RMSNorm(D)
self.gn = nn.GroupNorm(H,D)
ratio_0_to_1 = L_id / (L - 1)
ratio_1_to_almost0 = 1.0 - (L_id / L)
ddd = (jnp.arange(D)/D)[None]
self.x_maa = jnp.power(ddd, ratio_1_to_almost0)
self.r_maa = jnp.power(ddd, ratio_1_to_almost0)
self.w_maa = jnp.power(ddd, ratio_1_to_almost0)
self.k_maa = jnp.power(ddd, ratio_1_to_almost0)
self.v_maa = jnp.power(ddd, ratio_1_to_almost0)
self.g_maa = jnp.power(ddd, ratio_1_to_almost0)
TIME_MIX_EXTRA_DIM = 32
self.tm_w1 = random.uniform(keys[0],(D, TIME_MIX_EXTRA_DIM * 5),minval=-0.01,maxval=0.01)
self.tm_w2 = jnp.zeros((5, TIME_MIX_EXTRA_DIM, D))
W_MIX_EXTRA_DIM = 64
self.td_w1 = random.uniform(keys[1],(D, W_MIX_EXTRA_DIM),minval=-0.01,maxval=0.01)
self.td_w2 = jnp.zeros((W_MIX_EXTRA_DIM, D))
self.w = (-6 + 5 * (jnp.arange(D) /(D - 1)) ** (0.7 + 1.3 * ratio_0_to_1)).reshape((1,H,-1))
# n = jnp.arange(D)
# zigzag = ((n + 1) % 3 - 1) * 0.1
# self.u = (ratio_0_to_1 * (1 - (n / max(D - 1, 1))) + zigzag).reshape((H, -1, 1))
self.Wk = Linear(D,D,keys[2])
self.Wv = Linear(D,D,keys[3])
self.Wr = Linear(D,D,keys[4])
self.Wg = Linear(D,D,keys[5])
self.Wo = Linear(D,D,keys[6],0)
self.H = H
self.C = 128
def rwkv(self,k,v,r,w,s):
w = jnp.exp(-jnp.exp(w))
k,v,r = map(lambda x: x[:,None,:],[k,v,r])
w = w[:,:,None]
def loop(s,W_t):
k_t,v_t,r_t,w_t = W_t
s = w_t * s + k_t.T @ v_t
out_t = r_t @ s
return s,out_t.squeeze(0)
s,out = jax.lax.scan(loop,s,(k,v,r,w))
return s,out
# def parallel_log(self,Q,K,A):
def parallel_subchunk(self,Q,K,V,A):
C = 16
L,D = Q.shape
Q,K,V,A = map(lambda x: x.reshape((L//C,C,-1)),[Q,K,V,A])
out = jnp.zeros((L//C,C,D))
for i in range(L//C):
o = jnp.zeros((C,D))
q = Q[i]
a_q = A[i]
a_normalizer = a_q[0]
q = q * jnp.exp(a_q-a_normalizer[None])
for j in range(i):
k = K[j]
v = V[j]
a_kv = A[j]
k = k * jnp.exp(a_normalizer[None]-a_kv)
o += (q @ k.T) @ v
k = K[i]
v = V[i]
k = k * jnp.exp(a_normalizer[None]-a_q)
o += jnp.tril(jnp.matmul(q, k.T, precision=jax.lax.Precision('float32'))) @ v
out.at[i].set(o)
return out.reshape((L,D))
def chunkwise(self,q,k,v,w):
w = -jnp.exp(w)
C = self.C
L,D_q = q.shape
D_v = v.shape[-1]
q,k,v,w = map(lambda x: x.reshape((L//C,C,-1)),[q,k,v,w])
A_1 = jnp.sum(w,1)[:,None]
A_2 = jax.lax.cumsum(w,axis=1,reverse=True)-w
A_3 = jnp.cumsum(w,axis=1)
s = jnp.zeros((D_q,D_v))
def loop(s,W_t):
q_t,k_t,v_t,A_1_t,A_2_t,A_3_t,w_t = W_t
s_new = jnp.exp(A_1_t).T * s + (jnp.exp(A_2_t) * k_t).T @ v_t
q_ = q_t * jnp.exp(A_3_t)
# o = q_ @ s + self.parallel_subchunk(q_t,k_t,v_t,A_3_t)
o = q_ @ s + jnp.tril(q_ @ (k_t / (jnp.exp(A_3_t)+1e-12)).T) @ v_t
return s_new,o
_,out = jax.lax.scan(loop,s,(q,k,v,A_1,A_2,A_3,w))
return out.reshape((L,D_v))
def __call__(self,x,x_state,s,output_state):
T,D = orig_shape = x.shape
xx = jax.vmap(self.ln)(x)
sx,x_state = time_shift(xx,x_state)
sx = sx - xx
xxx = xx + sx * self.x_maa
xxx = jnn.tanh(xxx @ self.tm_w1).reshape((T, 5, -1)).swapaxes(0, 1)
mw, mk, mv, mr, mg = (xxx @ self.tm_w2).reshape((5, T, -1))
wx = xx + sx * (self.w_maa + mw)
kx = xx + sx * (self.k_maa + mk)
vx = xx + sx * (self.v_maa + mv)
rx = xx + sx * (self.r_maa + mr)
gx = xx + sx * (self.g_maa + mg)
w = self.w + (jnn.tanh(wx @ self.td_w1) @ self.td_w2).reshape((T,self.H,-1))
k = jax.vmap(self.Wk)(kx)
v = jax.vmap(self.Wv)(vx)
r = jax.vmap(self.Wr)(rx)
g = jax.vmap(self.Wg)(gx)
k,v,r,w = map(lambda x: x.reshape((T,self.H,-1)).swapaxes(0,1),[k,v,r,w])
g = jnn.silu(g)
if output_state:
s,out = jax.vmap(self.rwkv)(k,v,r,w,s)
else:
s = None
out = jax.vmap(self.chunkwise)(r,k,v,w)
out = out.swapaxes(0,1).reshape(orig_shape)
out = jax.vmap(self.gn)(out)
out = out.astype(x.dtype) * g
out = jax.vmap(self.Wo)(out)
return x + out,x_state,s
class Block(eqx.Module):
time_mix: TimeMix
channel_mix: ChannelMix
empty_x_state: Array = eqx.static_field()
empty_attn_state: Array = eqx.static_field()
empty_ffn_state: Array = eqx.static_field()
def __init__(self,L_id,L,H,D,key):
keys = random.split(key,2)
self.time_mix = TimeMix(L_id,L,H,D,keys[0])
self.channel_mix = ChannelMix(L_id,L,D,keys[1])
self.empty_x_state = jnp.zeros(D)
self.empty_attn_state = jnp.zeros((H,D//H,D//H))
self.empty_ffn_state = jnp.zeros(D)
def __call__(self, x, block_state, output_state):
if block_state is None:
block_state = (self.empty_x_state,self.empty_attn_state,self.empty_ffn_state)
x_state, attn_state, ffn_state = block_state
x, x_state, attn_state = self.time_mix(x, x_state, attn_state, output_state)
x, ffn_state = self.channel_mix(x, ffn_state)
return x,(x_state, attn_state, ffn_state)
class RWKV(eqx.Module):
emb: Array
ln_b: RMSNorm
blocks: List[Block]
ln_f: RMSNorm
final: Linear
def __init__(self,L,H,D,V,key):
keys = random.split(key,3)
keys_blocks = random.split(keys[0],L)
self.emb = jnn.initializers.orthogonal(1e-4)(keys[1],(V,D))
self.ln_b = RMSNorm(D)
self.blocks = [Block(i,L,H,D,key) for i,key in enumerate(keys_blocks)]
self.ln_f = RMSNorm(D)
self.final = Linear(D,V,keys[2],0.5)
def __call__(self,x, block_states=None, output_state=None):
x = jax.vmap(lambda x: self.emb[x])(x)
x = jax.vmap(self.ln_b)(x)
next_states = []
if block_states is None:
block_states = repeat(None)
for block, state in zip(self.blocks, block_states):
x,new_state = block(x, state, output_state)
if output_state:
next_states.append(new_state)
x = jax.vmap(self.ln_f)(x)
x = jax.vmap(self.final)(x)
if output_state:
return x, next_states
return x
def generate(self,input_tokens,max_len,key):
input_tokens = input_tokens.astype(jnp.int32)
keys = random.split(key,max_len)
_,new_states = self.__call__(input_tokens[:-1].reshape((-1,)),None,True)
next_token = input_tokens[-1]
all_tokens = jnp.empty((max_len),dtype=jnp.int32)
def loop(i,carry):
next_token,new_states,all_tokens = carry
logits,new_states = self.__call__(next_token.reshape((-1,)),new_states,True)
next_token_logits = logits[-1]/0.9
top_logits, top_tokens = jax.lax.top_k(next_token_logits, min(40, len(next_token_logits)))
token_idx = jax.random.categorical(keys[i], top_logits)
next_token = top_tokens[token_idx]
all_tokens = all_tokens.at[i].set(next_token)
return next_token,new_states,all_tokens
_,_,all_tokens = jax.lax.fori_loop(0,max_len,loop,(next_token,new_states,all_tokens))
print(decode(all_tokens.tolist()))
@eqx.filter_value_and_grad
def loss_fn(model, x, y):
pred_y = jax.vmap(model)(x)
loss = optax.softmax_cross_entropy_with_integer_labels(pred_y, y).mean()
return loss
@eqx.filter_jit(donate="all")
def make_step(model, x, y, opt_state, opt_update):
loss,grads = loss_fn(model, x, y)
updates, opt_state = opt_update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss
def get_shard():
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
shard = sharding.PositionalSharding(devices)
return shard
def print_avg_loss(losses,n):
loss = losses[-n:]
sum_ = sum(loss)
mean = sum_/len(loss)
print('Avg of last ' + str(n) + ' losses: ' + str(mean))
def print_params(model):
param_count = sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array)))
print('Number of Parameters: ' + str(param_count/1e6) + ' (M)')
def train(seed):
N = 10000
B = 128
L = 4
H = 4
D = 128
V = vocab_size
S = 1024
sample_len = 1024
LR = 0.01
key = random.PRNGKey(seed)
keys = random.split(key,3)
key_sampling = keys[0]
model = RWKV(L,H,D,V,keys[1])
print_params(model)
schedule = optax.linear_onecycle_schedule(N,LR)
opt = optax.chain(optax.clip_by_global_norm(1.),optax.adamw(schedule,b1=0.9,b2=0.98,weight_decay=0.))
opt_state = opt.init(eqx.filter(model, eqx.is_array))
shard = get_shard()
losses = []
pbar = tqdm(range(N),desc='training:', position=0, leave=True)
for i in pbar:
x,y = get_batch(B,S,'train')
x,y = jax.device_put((x,y),shard)
model, opt_state, loss = make_step(model, x, y, opt_state, opt.update)
if jnp.isnan(loss):
print(i,'nans!!!!!!')
break
pbar.set_postfix(train_loss=loss.item())
losses.append(np.asarray(loss))
if i % (N//10) == 0 or (i == N - 1):
print_avg_loss(losses,10)
x,_ = get_batch(B,S,'val')
prime = x[0][:50]
prime_str = decode(prime.tolist())
print(prime_str, "\n", "*" * 40)
key_sample = random.fold_in(key_sampling,i)
start = time.time()
model.generate(prime,sample_len,key_sample)
end = time.time()
print('Time Taken: ' + str(end-start))
plt.plot(np.log(losses))
plt.show()
return losses
```

Sorry for the messy code!
(Though not related to the subchunking)
```python
s_new = jnp.exp(A_1_t).T * s + (jnp.exp(A_2_t) * k_t).T @ v_t
```
should be
```python
s_new = jnp.exp(A_1_t) * s + (jnp.exp(A_2_t) * k_t).T @ v_t
```
?
Oh OK, I just thought that since in the paper alpha is 1 x d, setting beta = 1 makes G equal to alpha.T, which is d x 1. Thank you for pointing it out though!
`A_1 = jnp.sum(w,1)[:,None]`
"None" will make the last dimension size 1, and if you do the transpose you will have the gating in the V dimension instead of the K dimension?
W is of shape L//C x C x D, and `A_1 = jnp.sum(w,1)[:,None]` will make A_1 of shape L//C x 1 x D, so A_1_t will be of shape 1 x D. So if I'm not mistaken, the transpose would have the gating in the K dimension, which fits with the paper. If you replace D with D_k, then you would have to transpose to multiply by the state of shape D_k x D_v.
Sorry, I'm not familiar with JAX at all. My guess is that `jnp.sum(w,1)` will give L//C x D, and then `[:, None]` will make the resulting size L//C x D x 1. Why is it L//C x 1 x D? Is this specific to JAX?
From my understanding, if w has shape L//C x D, then `w[:,None]` is the same as `w[:,None,:]`, which is L//C x 1 x D.
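This is standard NumPy indexing semantics rather than anything JAX-specific; a quick throwaway check:

```python
import jax.numpy as jnp

w = jnp.zeros((8, 16, 64))              # (L//C, C, D)
A_1 = jnp.sum(w, 1)[:, None]
print(jnp.sum(w, 1).shape)              # (8, 64)
print(A_1.shape)                        # (8, 1, 64): [:, None] inserts a new axis 1
print(jnp.sum(w, 1)[..., None].shape)   # (8, 64, 1): what [..., None] would give
```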
Ah, you are right. I just misread it as `[..., None]`...
I read through the code and it looks good to me. I still have no idea why subchunking failed; maybe some gradient checking is needed. Given that I am not familiar with JAX, I cannot help debug further.
That’s fine. I can try to rewrite this in PyTorch to see if there is any difference, and maybe then you could help debug?
Sure!
Would it be better to remove all the RWKV-specific stuff for debugging? It would also be easier to code.
Yep, I believe so. Here is my previous debugging file for the Triton implementation: link. Everything there looks good.
As an update, I realised that the line `out.at[i].set(o)` should be `out = out.at[i].set(o)`, so `out` stayed at zero; it was a silly mistake. Now the performance is much better, but it is very unstable even at float32. I will try using the log semiring to fix that and will update you on how it goes. If possible, do you know an optimal implementation for that? Either way, thank you for helping, and sorry for wasting your time on such a simple mistake.
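For anyone else hitting this: JAX arrays are immutable, so `.at[...].set(...)` returns a new array rather than updating in place, and the result has to be assigned back. A minimal illustration:

```python
import jax.numpy as jnp

out = jnp.zeros(3)
out.at[0].set(1.0)        # returns a new array; `out` itself is unchanged
print(out)                # [0. 0. 0.]
out = out.at[0].set(1.0)  # assign the result back
print(out)                # [1. 0. 0.]
```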
Ah, good to hear! fp32 without regularization on the decay rate is prone to numerical issues! My suggestion for the next step:
Thank you for the reply!
Yes. This is part of the motivation for subchunking! If you use a small first-level chunk size, you get better numerical stability, but you need to materialize more states into HBM, which is slow. If you use a large first-level chunk size, the numerical issues become a nightmare. Subchunking avoids both problems simultaneously. In my experience, fp32 is enough for good numerical stability with a subchunk size of 16 and a reasonable regularization, e.g., setting the smallest log decay rate to -3. Note that e^-3 = 0.04978706836, which is reasonably small. Then the maximum value will be e^{3*16}, within the range of fp32. I don't think this is a strong regularization, because the model can still forget a lot of information.
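A minimal sketch of that regularization, assuming the decay is parameterized as `w = -exp(...)` per channel as in the notebook above (the -3 floor is the value suggested here, not something already in the code):

```python
import jax.numpy as jnp

def regularized_log_decay(w_raw):
    log_decay = -jnp.exp(w_raw)  # log decay rate, always negative
    # Floor the log decay at -3, i.e. decay >= e^-3 ~= 0.0498, so that within a
    # 16-token subchunk the rescaling factor is at most e^{3*16} = e^48 ~= 7e20,
    # comfortably inside the fp32 range.
    return jnp.maximum(log_decay, -3.0)
```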
Thanks! We've been discussing it and we're still a bit uncertain about one thing regarding the subchunked numerical stability you mention when using fp32 and clipping the decay to 0.05+. Maybe you could help clarify:
In subchunking, aren't we at some point comparing a key from the zero position within the chunk to a query from within the chunk's final subchunk? In that case the a_normalizer is from the last subchunk, meaning the key could be up to 128-16 = 112 positions away, and 0.05^-112 > 1e145, which would be infinity in fp32.
While the key could be up to 112 positions away, the inter-sub-chunk computation is always numerically stable. The only tricky part is the intra-sub-chunk part, and there the upper bound on the position distance is 16:
```python
for i in range(L//C):
    o = jnp.zeros((C,D))
    q = Q[i]
    a_q = A[i]
    a_normalizer = a_q[0]
    q = q * jnp.exp(a_q-a_normalizer[None])
    for j in range(i):
        k = K[j]
        v = V[j]
        a_kv = A[j]
        k = k * jnp.exp(a_normalizer[None]-a_kv)
        o += (q @ k.T) @ v
```
That is, the above code is always numerically stable because `a_normalizer[None] - a_kv` is always smaller than 0.
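A small sanity check of that claim (my own snippet, assuming a first-level chunk of 128 positions split into subchunks of 16, with A the within-chunk cumulative sum of strictly negative log decays as in `chunkwise` above):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
w = -jnp.exp(jax.random.normal(key, (128, 64)))  # strictly negative log decays
A = jnp.cumsum(w, axis=0).reshape((8, 16, 64))   # (L//C, C, D), decreasing with position

i, j = 7, 0                          # last subchunk vs. an earlier subchunk
a_normalizer = A[i][0]
print(jnp.max(a_normalizer - A[j]))  # < 0, so exp(...) <= 1: the inter-sub-chunk part is safe
print(jnp.max(a_normalizer - A[i]))  # >= 0 but bounded by at most 16 decay steps: the intra part
```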
Thank you for the explanation! We've been finding good success (both speed and accuracy) with some variations on these themes in both PyTorch and JAX now.
Good to know! Hope our hardware-efficient method helps with your RWKV-v6 training :)
I have been experimenting with GLA and implementing it in pure PyTorch or JAX (no Triton), and I have run into performance problems with subchunking. When using the chunkwise method on its own, I get promising results in terms of speed and performance.
However, to make it work I have to add an epsilon, `k / (A + 1e-12)`, or else `k` will explode very quickly, which is expected. To avoid the epsilon, I tried subchunking (almost the same as the pseudocode); although speed is not a big problem, the performance becomes significantly worse.
The subchunking performance ends up being much worse than both chunkwise and recurrent.
Is there any reason for this, and how could it be fixed?
PS: The main purpose of this is to hopefully use the GLA methods for a CUDA- and Triton-free RWKV-v6.
Code for subchunking in JAX:
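(The snippet itself did not survive the export; it is essentially the `parallel_subchunk` method from the notebook above. A condensed reconstruction for reference, with the `out = out.at[i].set(o)` fix noted earlier in the thread already applied:)

```python
import jax.numpy as jnp

def parallel_subchunk(Q, K, V, A, C=16):
    # Q, K, V, A: (L, D) tensors for one first-level chunk;
    # A is the within-chunk cumulative log decay.
    L, D = Q.shape
    Q, K, V, A = map(lambda x: x.reshape((L // C, C, -1)), [Q, K, V, A])
    out = jnp.zeros((L // C, C, D))
    for i in range(L // C):
        o = jnp.zeros((C, D))
        q, a_q = Q[i], A[i]
        a_normalizer = a_q[0]
        q = q * jnp.exp(a_q - a_normalizer[None])
        for j in range(i):                               # inter-sub-chunk contributions
            k = K[j] * jnp.exp(a_normalizer[None] - A[j])
            o += (q @ k.T) @ V[j]
        k = K[i] * jnp.exp(a_normalizer[None] - a_q)     # intra-sub-chunk (causal) part
        o += jnp.tril(q @ k.T) @ V[i]
        out = out.at[i].set(o)                           # .at returns a new array; reassign
    return out.reshape((L, D))
```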