likejazz / llama3.np

llama3.np is a pure NumPy implementation of the Llama 3 model.
MIT License
958 stars 73 forks source link

A simplified version #2

Open scturtle opened 4 months ago

scturtle commented 4 months ago
import numpy as np

class ModelArgs:
    """Fixed model hyperparameters, read directly as class attributes.

    NOTE(review): the values are hard-coded; presumably they match the
    ``stories15M.model.npz`` checkpoint that ``main`` loads — confirm.
    """

    dim: int = 288          # model width; dim // n_heads gives the per-head size used for RoPE
    n_layers: int = 6       # number of transformer blocks iterated by llama3()
    n_heads: int = 6        # number of query heads attn() splits activations into
    norm_eps: float = 1e-6  # default epsilon added inside rmsnorm()'s sqrt

def build_cos_sin_cache(head_dim, seq_len, base=10000):
    """Precompute the cosine/sine tables used by rotary position embeddings.

    Args:
        head_dim: size of one attention head; one frequency is produced for
            every even index ``0, 2, 4, ...`` below it.
        seq_len: number of positions (rows) to precompute, starting at 0.
        base: frequency base; frequency ``i`` is ``base ** (-2i / head_dim)``.

    Returns:
        ``(cos, sin)`` float32 arrays of shape ``(seq_len, ceil(head_dim / 2))``
        where entry ``[p, i]`` is ``cos``/``sin`` of ``p * frequency_i``.
    """
    exponents = np.arange(0, head_dim, 2, dtype=np.float32) / head_dim
    inv_freq = 1. / (base ** exponents)
    positions = np.arange(seq_len, dtype=np.float32)
    # Column-times-row broadcast: the same product np.outer computes.
    angles = positions[:, None] * inv_freq[None, :]
    return np.cos(angles), np.sin(angles)

# Module-level RoPE tables read by every rope() call: one row per absolute
# position, covering positions 0..255, with head_dim = dim // n_heads.
cos_cached, sin_cached = build_cos_sin_cache(
    head_dim=ModelArgs.dim // ModelArgs.n_heads,
    seq_len=256,
)

def rope(x, start_pos):
    """Apply rotary position embeddings to interleaved (even, odd) feature pairs.

    Args:
        x: array indexed as ``x[batch, position, head, feature]``; the feature
            axis is expected to have even length (pairs ``0::2`` / ``1::2``).
        start_pos: absolute position of ``x[:, 0]``.

    Returns:
        A new array like ``x`` where each pair ``(a, b)`` at position ``p`` is
        rotated by the angles in row ``p`` of the module-level RoPE tables.

    Raises:
        ValueError: if the requested positions fall outside the precomputed
            tables. Without this check an overrun can silently broadcast a
            single table row across several positions.
    """
    seq_len = x.shape[1]
    max_pos = cos_cached.shape[0]
    if start_pos < 0 or start_pos + seq_len > max_pos:
        raise ValueError(
            f"rope positions [{start_pos}, {start_pos + seq_len}) are outside "
            f"the precomputed range [0, {max_pos})"
        )
    cos = cos_cached[start_pos:start_pos + seq_len][None, :, None, :]
    sin = sin_cached[start_pos:start_pos + seq_len][None, :, None, :]
    even = x[:, :, :, 0::2]
    odd = x[:, :, :, 1::2]
    r = np.zeros_like(x)
    r[:, :, :, 0::2] = even * cos - odd * sin
    r[:, :, :, 1::2] = odd * cos + even * sin
    return r

def softmax(x):
    """Return the softmax of ``x`` along its last axis.

    The row maximum is subtracted before exponentiating so ``np.exp`` cannot
    overflow; this does not change the result.
    """
    row_max = np.max(x, axis=-1, keepdims=True)
    unnormalized = np.exp(x - row_max)
    totals = np.sum(unnormalized, axis=-1, keepdims=True)
    return unnormalized / totals

def silu(x):
    """SiLU / swish activation, ``x * sigmoid(x)``, applied element-wise.

    The sigmoid uses a stable form. ``exp(-|x|)`` never overflows, whereas the
    naive ``1 / (1 + exp(-x))`` overflows for large negative inputs. That
    overflow emits a RuntimeWarning, or raises under
    ``np.errstate(over="raise")``.
    """
    e = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x)  (same value, no overflow)
    sig = np.where(x >= 0, 1 / (1 + e), e / (1 + e))
    return x * sig

def ffn(x, up_wgt, gate_wgt, down_wgt):
    """Gated feed-forward layer: ``(silu(x @ gate) * (x @ up)) @ down``.

    All weight matrices right-multiply the activations, so they are expected
    in ``(in_features, out_features)`` layout.
    """
    gated = silu(x @ gate_wgt)
    projected_up = x @ up_wgt
    return (gated * projected_up) @ down_wgt

def rmsnorm(x, eps=ModelArgs.norm_eps):
    """Divide ``x`` by its root-mean-square over the last axis.

    No learned scale is applied here; callers multiply the result by the
    layer's norm weight themselves. ``eps`` keeps the sqrt away from zero.
    """
    mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(mean_square + eps)

def attn(x, start_pos, q_wgt, k_wgt, v_wgt, o_wgt, cache):
    """Causal multi-head self-attention with RoPE and a per-layer KV cache.

    Args:
        x: ``(B, L, d)`` normalized hidden states for the new tokens.
        start_pos: absolute position of ``x[:, 0]``. It is expected to equal
            the number of tokens already held in ``cache``.
        q_wgt, k_wgt, v_wgt, o_wgt: projection matrices that right-multiply
            the activations. ``k_wgt``/``v_wgt`` may project to fewer heads
            than ``q_wgt`` (grouped-query attention). The KV head count is
            inferred from their output width.
        cache: list that is empty before the first call. It is replaced in
            place with ``[k, v]`` holding every key/value seen so far
            (post-RoPE, before head repetition).

    Returns:
        ``(B, L, out)`` array: attention output projected through ``o_wgt``.
    """
    B, L, d = x.shape
    head_dim = d // ModelArgs.n_heads

    # -1 lets K/V carry fewer heads than Q; for plain multi-head attention it
    # resolves to n_heads.
    q = (x @ q_wgt).reshape((B, L, ModelArgs.n_heads, head_dim))
    k = (x @ k_wgt).reshape((B, L, -1, head_dim))
    v = (x @ v_wgt).reshape((B, L, -1, head_dim))

    q = rope(q, start_pos)
    k = rope(k, start_pos)

    if cache:
        k_cache, v_cache = cache
        k = np.concatenate([k_cache, k], axis=1)
        v = np.concatenate([v_cache, v], axis=1)
    cache[:] = [k, v]

    # Repeat each KV head so every query head has a matching key/value head.
    n_rep = q.shape[-2] // k.shape[-2]
    k = np.repeat(k, n_rep, axis=-2)
    v = np.repeat(v, n_rep, axis=-2)

    scores = np.einsum('...qhd,...khd->...hqk', q, k) * q.shape[-1] ** -0.5
    if L > 1:
        # Query i sits at absolute slot T - L + i among the T keys, so it must
        # not see keys j > T - L + i. The mask is (L, T) rather than (T, T), so
        # it also broadcasts correctly when earlier tokens are already cached.
        T = scores.shape[-1]
        mask = np.triu(np.full((L, T), -1e10, dtype=scores.dtype), k=T - L + 1)
        scores = scores + mask
    probs = softmax(scores)
    out = np.einsum('...hqk,...khd->...qhd', probs, v)
    out = out.reshape(out.shape[:-2] + (-1,))
    return out @ o_wgt

def block(x, start_pos, layer_id, weights, cache):
    """Apply one pre-norm transformer layer (attention, then gated FFN).

    Args:
        x: hidden states. They are not modified; a new array is returned.
        start_pos: absolute position of the first token in ``x``, forwarded
            to ``attn``.
        layer_id: index used to build the ``model.layers.{layer_id}.*`` keys.
        weights: mapping from parameter name to array. Projection matrices
            are expected pre-transposed to right-multiply activations
            (``main`` transposes them after loading).
        cache: this layer's KV-cache list, updated in place by ``attn``.

    Returns:
        Updated hidden states with the same shape as ``x``.
    """
    rms_wgt_in = weights[f"model.layers.{layer_id}.input_layernorm.weight"]
    q_wgt = weights[f"model.layers.{layer_id}.self_attn.q_proj.weight"]
    k_wgt = weights[f"model.layers.{layer_id}.self_attn.k_proj.weight"]
    v_wgt = weights[f"model.layers.{layer_id}.self_attn.v_proj.weight"]
    o_wgt = weights[f"model.layers.{layer_id}.self_attn.o_proj.weight"]
    rms_wgt_out = weights[f"model.layers.{layer_id}.post_attention_layernorm.weight"]
    up_wgt = weights[f"model.layers.{layer_id}.mlp.up_proj.weight"]
    gate_wgt = weights[f"model.layers.{layer_id}.mlp.gate_proj.weight"]
    down_wgt = weights[f"model.layers.{layer_id}.mlp.down_proj.weight"]

    # The residual adds build new arrays instead of using ``x += ...``, so the
    # caller's array is never overwritten.
    norm_x = rmsnorm(x) * rms_wgt_in
    h = x + attn(norm_x, start_pos, q_wgt, k_wgt, v_wgt, o_wgt, cache)
    norm_h = rmsnorm(h) * rms_wgt_out
    return h + ffn(norm_h, up_wgt, gate_wgt, down_wgt)

def llama3(x, start_pos, weights, caches):
    """Run the whole model on token ids ``x`` and return next-token logits.

    Args:
        x: integer token ids used to index the embedding table (``main``
            passes a ``(batch, seq)`` array).
        start_pos: absolute position of the first token in ``x``.
        weights: parameter mapping. ``lm_head.weight`` is expected
            pre-transposed to right-multiply the final hidden states.
        caches: one KV-cache list per layer, indexed by layer id.

    Returns:
        Logits for every input position (final hidden states @ lm_head).
    """
    hidden = weights["model.embed_tokens.weight"][x]
    for layer_id in range(ModelArgs.n_layers):
        hidden = block(hidden, start_pos, layer_id=layer_id,
                       weights=weights, cache=caches[layer_id])
    normed = rmsnorm(hidden) * weights["model.norm.weight"]
    return normed @ weights["lm_head.weight"]

def main():
    """Greedy-decode a continuation of a fixed prompt and stream it to stdout.

    Loads the tokenizer and checkpoint from the working directory. It then
    prefills the prompt in one call and decodes one token per step until the
    sequence holds 56 positions' worth of steps.
    """
    from tokenizer import Tokenizer
    tokenizer = Tokenizer("./tokenizer.model.np")

    # Context manager closes the .npz archive once the arrays are copied out.
    with np.load("./stories15M.model.npz") as data:
        weights = dict(data)
    # Projection/lm_head matrices are transposed so activations right-multiply them.
    for k in weights:
        if k.endswith('proj.weight') or k == "lm_head.weight":
            weights[k] = weights[k].T

    prompt = "I have a dream"
    print(f"{prompt}", end="", flush=True)
    x = np.array([tokenizer.encode(prompt)])
    caches = [[] for _ in range(ModelArgs.n_layers)]
    max_len = 56
    # Absolute position of the next token fed to the model. The original code
    # fed the first generated token at prompt_len + 1 instead of prompt_len,
    # leaving a one-position RoPE gap relative to the cached prompt keys.
    pos = 0
    for _ in range(x.shape[1], max_len):
        logits = llama3(x, pos, weights, caches)
        pos += x.shape[1]
        x = np.argmax(logits[:, -1, :], axis=-1, keepdims=True)
        print(tokenizer.decode(x[0]), end="", flush=True)

# Run generation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
likejazz commented 4 months ago

Awesome! Your code is very clean and readable. I love the style. Thanks!