karpathy / llama2.c

Inference Llama 2 in one file of pure C
MIT License
16.95k stars 1.99k forks source link

I added bidirectional attention, and those who need it can study it. #427

Open win10ogod opened 9 months ago

win10ogod commented 9 months ago

This is the original attention module that I modified:

class Attention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and shared K/V heads.

    When the installed PyTorch exposes ``scaled_dot_product_attention`` that
    kernel is used with ``is_causal=True``; otherwise an explicit
    softmax(QK^T / sqrt(head_dim)) with an additive upper-triangular ``-inf``
    mask is computed.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Fall back to one K/V head per query head when n_kv_heads is not given.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # How many query heads share each K/V head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        q_width = args.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.wq = nn.Linear(args.dim, q_width, bias=False)
        self.wk = nn.Linear(args.dim, kv_width, bias=False)
        self.wv = nn.Linear(args.dim, kv_width, bias=False)
        self.wo = nn.Linear(q_width, args.dim, bias=False)

        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # Prefer the fused kernel; only the fallback needs a precomputed mask.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            mask_shape = (1, 1, args.max_seq_len, args.max_seq_len)
            causal = torch.triu(torch.full(mask_shape, float("-inf")), diagonal=1)
            self.register_buffer("mask", causal)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Apply causal self-attention to ``x`` of shape (bsz, seqlen, dim).

        ``freqs_cos`` / ``freqs_sin`` are passed straight to ``apply_rotary_emb``.
        Returns the output of the ``wo`` projection followed by residual dropout.
        """
        bsz, seqlen, _ = x.shape

        # Project, then split the channel dimension into (heads, head_dim).
        queries = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
        keys = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        values = self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Rotary position embedding on queries and keys only.
        queries, keys = apply_rotary_emb(queries, keys, freqs_cos, freqs_sin)

        # Expand K/V heads to match the query head count, then move heads
        # ahead of time: (bs, n_local_heads, seqlen, head_dim).
        keys = repeat_kv(keys, self.n_rep).transpose(1, 2)
        values = repeat_kv(values, self.n_rep).transpose(1, 2)
        queries = queries.transpose(1, 2)

        if self.flash:
            drop_p = self.dropout if self.training else 0.0
            attended = F.scaled_dot_product_attention(
                queries, keys, values, attn_mask=None, dropout_p=drop_p, is_causal=True
            )
        else:
            # Explicit attention: scaled dot product, causal mask, softmax in float32.
            weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
            assert hasattr(self, 'mask')
            weights = weights + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, seqlen)
            weights = F.softmax(weights.float(), dim=-1).type_as(queries)
            weights = self.attn_dropout(weights)
            attended = torch.matmul(weights, values)  # (bs, n_local_heads, seqlen, head_dim)

        # Back to (bs, seqlen, n_local_heads * head_dim), then project into the residual stream.
        merged = attended.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.resid_dropout(self.wo(merged))

This is my implementation

class BidirectionalAttention(nn.Module):
    """Multi-head self-attention in which every position may attend to every position.

    Same projections, rotary embedding and grouped K/V heads as the causal
    variant, but the additive attention mask is all zeros, so no position is
    hidden from any other.
    """

    def __init__(self, args: "ModelArgs"):
        # Fixed: the original called `super().__init()`, which raises
        # AttributeError before any submodule can be registered.
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # Registered as a buffer so it follows `.to(device)` with the module;
        # non-persistent so the state_dict keys stay the same as when the mask
        # was a plain attribute.
        self.register_buffer(
            "mask",
            self.generate_bidirectional_mask(args.max_seq_len),
            persistent=False,
        )

    def generate_bidirectional_mask(self, max_seq_len):
        """Return the additive attention mask of shape (1, 1, max_seq_len, max_seq_len).

        Fixed: the original summed an upper-triangular ``-inf`` matrix with its
        transpose, which is ``-inf`` everywhere except the diagonal, so after
        softmax each token attended only to itself. Bidirectional attention
        hides nothing, which as an additive mask is all zeros.
        """
        return torch.zeros((1, 1, max_seq_len, max_seq_len))

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Apply unmasked self-attention to ``x`` of shape (bsz, seqlen, dim).

        ``freqs_cos`` / ``freqs_sin`` are passed straight to ``apply_rotary_emb``.
        Returns the output of the ``wo`` projection followed by residual dropout.
        """
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # Grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # Make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Manual attention; the mask is all zeros, so every key stays visible.
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = scores + self.mask[:, :, :seqlen, :seqlen]
        # Softmax in float32, then cast back to the query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        # Restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # Final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

In my fork, however, I kept the original class name (`Attention`), because renaming the class causes errors elsewhere in the code.

win10ogod commented 9 months ago

My github fork: https://github.com/win10ogod/llama2.c/blob/master/model.py

win10ogod commented 9 months ago

image Training loss on TinyStories dataset.

dbl001 commented 7 months ago

How did you get the loss to '0.4' on 'tinystories'? I'm still at 3.5 after ~2,000 iterations.

# Training configuration used for the run reported below.
# data
batch_size = 8 # if gradient_accumulation_steps > 1, this is the micro-batch size
max_seq_len = 1024
vocab_source = "custom" # llama2|custom; use Llama 2 vocab from Meta, or custom trained
# NOTE(review): the run shown below passes --vocab_size=4096 on the command
# line; presumably that overrides the value set here — confirm.
vocab_size = 32000 # the Llama 2 tokenizer has 32K tokens
# model
dim = 768
n_layers = 12
n_heads = 12
n_kv_heads = 12
multiple_of = 32
dropout = 0.0
# adamw optimizer
gradient_accumulation_steps = 4  # used to simulate larger batch sizes
learning_rate = 5e-5  # max learning rate
max_iters = 2000  # total number of training iterations
#weight_decay = 1e-1
weight_decay = 1e-5
beta1 = 0.9
#beta2 = 0.95
beta2 = 0.9999
grad_clip = 1.0  # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True  # whether to decay the learning rate
warmup_iters = 500  # how many steps to warm up for
# system
device = "mps"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = "float32"  # float32|bfloat16|float16
compile = False  # use PyTorch 2.0 to compile the model to be faster

$ python train.py --vocab_source=custom --vocab_size=4096
...
1200 | loss 3.4585 | lr 2.761321e-05 | 165287.20ms | mfu 0.39%
1210 | loss 3.5833 | lr 2.709195e-05 | 15341.66ms | mfu 0.40%
1220 | loss 3.4610 | lr 2.656976e-05 | 15354.62ms | mfu 0.40%
1230 | loss 3.5436 | lr 2.604689e-05 | 15425.17ms | mfu 0.40%
1240 | loss 3.5393 | lr 2.552356e-05 | 16301.26ms | mfu 0.40%
1250 | loss 3.4666 | lr 2.500000e-05 | 16101.18ms | mfu 0.41%
...
$ ./run out/model.bin -z data/tok4096.bin  -i "Once upon a time, there was a little girl named Lily."
Once upon a time, there was a little girl named Lily. old time magic time cake magic new little dog game time good red time time fun time friend nice white too we we and little time lovely nice even eyes time while b boy time, they,. all to to to, long and little teddy fur and,.", friend,,. dog,,. little cake again big., and and his he too bunny cold and they little it can his. to you so it shiny a she and it. girl and that you, Tim Tim I Tim and little the how so time long and had a he they Timmy they that they Sam people on and p thought even new the the took a " the a Tim the came he on round I it Lily the a so. Anna, Lily Sally the they up they her they, that with shiny his ch she a said there he the they get then a take the the b you he to the be a it sad and, holding went. blue she a the the he her the they Sue do mom Timmy had p funny with it the Lily the a always a little a big new Max, made saw saw be it learned he and he shiny go get the she said her all to Sam other swim, Lily how a all
achieved tok/s: 4.537851

Bi-directional training:

Screenshot 2023-12-18 at 6 34 58 AM

Unidirectional training loss after 2,000 iterations on 'MPS' was 1.7.

Screenshot 2023-12-18 at 9 49 49 AM
win10ogod commented 7 months ago

float32

With bfloat16 the loss drops further than it does with float32.

win10ogod commented 5 months ago

@dbl001 Here is a new implementation, completed with the help of qwen1.5-72b-chat:

class BiAttention(nn.Module):
    """Two-branch self-attention: a causal branch plus a time-reversed branch.

    The forward branch attends causally over ``x``. The backward branch has its
    own Q/K/V projections, runs on ``x`` reversed along the sequence axis, and
    its result is reversed back before being concatenated with the forward
    result. ``wo`` therefore takes ``2 * n_heads * head_dim`` input features.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        # Separate projections for the time-reversed branch.
        self.wq_back = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk_back = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv_back = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        # Input width is doubled: forward and backward outputs are concatenated.
        self.wo = nn.Linear(args.n_heads * self.head_dim * 2, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # use flash attention or a manual implementation?
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # Causal mask, used by the forward branch of the manual path only.
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    @staticmethod
    def _merge_heads(heads_out, reverse_time=False):
        """Turn (bs, n_heads, seqlen, head_dim) into (bs, seqlen, n_heads * head_dim).

        With ``reverse_time=True`` the sequence axis is flipped as well, which
        undoes the time reversal applied to the backward branch's input.
        """
        merged = heads_out.transpose(1, 2)  # (bs, seqlen, n_heads, head_dim)
        if reverse_time:
            # After the transpose, dim 1 is time; flipping before it would
            # reverse the head order instead.
            merged = merged.flip(dims=[1])
        bsz, seqlen = merged.shape[:2]
        return merged.contiguous().view(bsz, seqlen, -1)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Apply both attention branches to ``x`` of shape (bsz, seqlen, dim).

        ``freqs_cos`` / ``freqs_sin`` are passed straight to ``apply_rotary_emb``
        for both branches. Returns the output of the ``wo`` projection
        followed by residual dropout.
        """
        bsz, seqlen, _ = x.shape

        # QKV (forward)
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # QKV (backward): reverse the sequence once and reuse it for all three projections
        x_rev = x.flip(dims=[1])
        xq_back, xk_back, xv_back = self.wq_back(x_rev), self.wk_back(x_rev), self.wv_back(x_rev)
        xq_back = xq_back.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk_back = xk_back.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv_back = xv_back.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        xq_back, xk_back = apply_rotary_emb(xq_back, xk_back, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xk_back = repeat_kv(xk_back, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv_back = repeat_kv(xv_back, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        xq_back = xq_back.transpose(1, 2)
        xk_back = xk_back.transpose(1, 2)
        xv_back = xv_back.transpose(1, 2)

        # NOTE(review): the backward branch is unmasked in both paths. The
        # manual path used to add a doubly-flipped causal mask here, which made
        # it compute something different from the flash path; it now follows
        # the flash path (is_causal=False). If the backward branch is meant to
        # see only later tokens, use is_causal=True / the plain causal mask in
        # BOTH paths instead.
        if self.flash:
            dropout_p = self.dropout if self.training else 0.0
            output_fw = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=dropout_p, is_causal=True)
            output_bw = torch.nn.functional.scaled_dot_product_attention(xq_back, xk_back, xv_back, attn_mask=None, dropout_p=dropout_p, is_causal=False)
        else:
            # manual implementation
            scores_fw = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            scores_bw = torch.matmul(xq_back, xk_back.transpose(2, 3)) / math.sqrt(self.head_dim)
            assert hasattr(self, 'mask')
            scores_fw = scores_fw + self.mask[:, :, :seqlen, :seqlen]   # (bs, n_local_heads, seqlen, seqlen)
            scores_fw = F.softmax(scores_fw.float(), dim=-1).type_as(xq)
            scores_bw = F.softmax(scores_bw.float(), dim=-1).type_as(xq_back)
            scores_fw = self.attn_dropout(scores_fw)
            scores_bw = self.attn_dropout(scores_bw)
            output_fw = torch.matmul(scores_fw, xv)  # (bs, n_local_heads, seqlen, head_dim)
            output_bw = torch.matmul(scores_bw, xv_back)  # (bs, n_local_heads, seqlen, head_dim)

        # restore time as batch dimension and concat heads (forward and backward).
        # Fixed: the original flipped dim 1 while it was still the head axis,
        # so the backward output stayed time-reversed and had its heads reversed.
        output_fw = self._merge_heads(output_fw)
        output_bw = self._merge_heads(output_bw, reverse_time=True)
        output = torch.cat([output_fw, output_bw], dim=-1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output
    `