chrisociepa / allamo

Simple, hackable and fast implementation for training/finetuning medium-sized LLaMA-based models
MIT License

Hope to merge llama3 into your project, thank you for your help #18

Open sankexin opened 1 week ago

sankexin commented 1 week ago

This is a great project: open-source training from scratch, simple and easy to use, and especially accessible to ordinary people.

The current state-of-the-art models are highly similar to llama3 in architecture. I hope everyone can train llama3 from scratch, which may help people interested in new algorithms to build on it, propose new ones, and so promote progress. Therefore, I have followed your project's format to implement a preliminary llama3 and hope it can be merged into your project.

Replace the code in allamo/model/model.py with the following code, and it will work:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
import inspect
from dataclasses import dataclass
from typing import Optional, Tuple
# import fairscale.nn.model_parallel.initialize as fs_init
import torch
from torch import nn
import torch.nn.functional as F
# from fairscale.nn.model_parallel.layers import (
#     ColumnParallelLinear,
#     RowParallelLinear,
#     VocabParallelEmbedding,
# )
from allamo.logging import logger
from allamo.model.attentions import attention_version

## simple llama3
@dataclass
class AllamoTransformerConfig:
    dim: int = 2048
    n_layers: int = 12
    n_heads: int = 12
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_batch_size: int = 32
    max_seq_len: int = 2048

'''
## origin llama3
@dataclass
class AllamoTransformerConfig:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_batch_size: int = 32
    max_seq_len: int = 2048
'''

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

class Attention(nn.Module):
    def __init__(self, args: AllamoTransformerConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # model_parallel_size = fs_init.get_model_parallel_world_size()
        # self.n_local_heads = args.n_heads // model_parallel_size
        # self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        '''
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )
        '''

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        '''
        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        '''

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: AllamoTransformerConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

class Transformer(nn.Module):
    def __init__(self, params: AllamoTransformerConfig):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        # self.tok_embeddings = VocabParallelEmbedding(
        #     params.vocab_size, params.dim, init_method=lambda x: x
        # )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        # self.output = ColumnParallelLinear(
        #     params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        # )

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output

class AllamoTransformer(nn.Module):
    def __init__(self, params: AllamoTransformerConfig):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        # self.tok_embeddings = VocabParallelEmbedding(
        #     params.vocab_size, params.dim, init_method=lambda x: x
        # )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        # self.output = ColumnParallelLinear(
        #     params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        # )

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )

        self.log_estimated_size()

    def estimate_size(self):
        """
        Return the number of parameters and their size in the model.
        """
        params = 0
        bytes = 0
        for p in self.parameters():
            params += p.numel()
            bytes += p.numel() * p.element_size()
        for b in self.buffers():
            # don't count buffers as params
            bytes += b.numel() * b.element_size()
        return params, bytes

    def log_estimated_size(self):
        self.model_num_params, self.model_num_bytes = self.estimate_size()
        model_params = self.model_num_params / 1e6
        model_bytes = self.model_num_bytes / 1024**2
        logger.info(f"Model parameters: {model_params:.2f}M, Est. Size: {model_bytes:.3f}MB")

    def forward(self, 
        input_ids: torch.Tensor, 
        input_pos: Optional[int] = 0, 
        target_ids: Optional[torch.Tensor] = None, 
        target_weights: Optional[torch.Tensor] = None, 
        attn_mask: Optional[torch.Tensor] = None,
        ignore_index: Optional[int] = -100,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        tokens = input_ids
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        start_pos = input_pos = 0  # note: the passed input_pos is overridden; this training path always starts at position 0
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        final_embeddings = self.norm(h)

        if target_ids is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.output(final_embeddings)
            if target_weights is None:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=ignore_index)
            else:
                loss = (target_weights.view(-1) * F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1), reduction="none")).sum()
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            # logits = self.lm_head(final_embeddings[:, [-1], :]) # note: using list [-1] to preserve the time dim
            logits = self.output(final_embeddings[:, [-1], :])
            loss = None

        return logits, loss, h

    def configure_optimizers(self, config, device_type):
        # start with all of the candidate parameters
        param_dict = {param_name: p for param_name, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {param_name: p for param_name, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        def is_weight_decay_forbidden(param_name):
            return param_name.endswith('.bias') or param_name.endswith('_norm.weight') or param_name == 'norm.weight'
        decay_params = [p for n, p in param_dict.items() if not is_weight_decay_forbidden(n)]
        nodecay_params = [p for n, p in param_dict.items() if is_weight_decay_forbidden(n)]
        optim_groups = [
            {'params': decay_params, 'weight_decay': config.weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        logger.info(f"Decayed parameter tensors: {len(decay_params):,}, with {num_decay_params:,} parameters")
        logger.info(f"Non-decayed parameter tensors: {len(nodecay_params):,}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=config.learning_rate, betas=(config.beta1, config.beta2), **extra_args)
        logger.info(f"Using fused AdamW: {use_fused}")

        return optimizer

    @torch.no_grad()
    def generate_embeddings(self, tokens):
        # this class defines tok_embeddings (not token_embeddings) and has no apply_layers helper,
        # so reuse the forward pass above and normalize the returned hidden states
        _, _, h = self(tokens)
        return self.norm(h)

    @torch.no_grad()
    def generate(self, tokens, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of tokens (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        # AllamoTransformerConfig has no block_size field; use max_seq_len from self.params
        if tokens.size(1) > self.params.max_seq_len:
            logger.info(
                f"Input of {tokens.size(1)} tokens exceeds limit {self.params.max_seq_len}. "
                f"Initial tokens will be dropped to fit."
            )
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at max_seq_len
            if tokens.size(1) > self.params.max_seq_len:
                tokens = tokens[:, -self.params.max_seq_len:]
            # forward the model to get the logits for the tokens
            logits = self(tokens)[0]
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1)
            # append next token to the running sequence and continue
            tokens = torch.cat((tokens, next_token), dim=1)

        return tokens
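
For anyone who wants to sanity-check the classes above before wiring them into the training loop, here is a minimal forward-pass sketch. It is my own example, not part of the patch: it assumes a CUDA device is available (the Attention cache is allocated with .cuda()) and uses the small "simple llama3" defaults with a made-up vocabulary size.

# minimal smoke test (sketch, assuming the module above is defined and a GPU is available)
import torch

config = AllamoTransformerConfig(vocab_size=50307, max_batch_size=2, max_seq_len=128)
model = AllamoTransformer(config).cuda()

tokens = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
targets = torch.randint(0, config.vocab_size, (2, 16), device="cuda")

# with target_ids provided, forward returns full logits, the cross-entropy loss and the hidden states
logits, loss, hidden = model(input_ids=tokens, target_ids=targets)
print(logits.shape, loss.item(), hidden.shape)  # (2, 16, 50307), scalar loss, (2, 16, 2048)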

Training with this model, the result looks like this:

model:  AllamoTransformer(
  (tok_embeddings): Embedding(50307, 2048)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=2048, out_features=2040, bias=False)
        (wk): Linear(in_features=2048, out_features=2040, bias=False)
        (wv): Linear(in_features=2048, out_features=2040, bias=False)
        (wo): Linear(in_features=2040, out_features=2048, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=2048, out_features=5632, bias=False)
        (w2): Linear(in_features=5632, out_features=2048, bias=False)
        (w3): Linear(in_features=2048, out_features=5632, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=2048, out_features=50307, bias=False)
)
2024-11-21 02:55:22,447 - root - INFO - New model initialized from scratch
2024-11-21 02:55:24,200 - root - INFO - Compiling model
2024-11-21 02:55:25,391 - root - INFO - Model compiled and ready to use
2024-11-21 02:55:25,393 - root - INFO - Decayed parameter tensors: 86, with 821,833,728 parameters
2024-11-21 02:55:25,393 - root - INFO - Non-decayed parameter tensors: 25, with 51,200 parameters
2024-11-21 02:55:25,394 - root - INFO - Using fused AdamW: True
2024-11-21 02:55:25,394 - root - INFO - Cosing decay learning rate enabled. Currect learning rate: 0.0
2024-11-21 02:55:25,395 - root - INFO - Starting training (run id: 6b1b8fec-3f23-43d1-a94c-0fec2e31e28f, world size: 1) with configuration:
AllamoConfiguration(load_configuration=True, init_from='scratch', checkpoint_path=None, seed=1337, data_dir='../data/', out_dir='../data/out-allamo-1B/', log_checkpoint_md5_on_load=False, log_checkpoint_md5_on_epoch=False, ignore_last_checkpoint_backup=False, checkpoint_interval=1000, save_optimizer_checkpoint=True, optimizer_checkpoint_interval=None, save_best_checkpoint=False, save_checkpoint_on_dataset_reload=False, distributed_checkpoint=False, config_override_check_interval=None, config_override_path=None, eval_interval=1000, eval_iters=200, eval_only=False, log_interval=1, vocab_size=50307, tiktoken_tokenizer_name=None, hf_tokenizer_path=None, wandb_log=False, wandb_project='allamo', wandb_run_name='allamo-1B', gradient_checkpointing=False, gradient_accumulation_steps=264, batch_size=2, block_size=2048, sliding_window=None, dataset='allamo_1B_dataset', dataset_train_files=None, dataset_validation_files=None, dataset_train_file_prefix='train.', dataset_validation_file_prefix='val.', dataset_train_processed_files_count=0, dataset_seq_train=True, dataset_seq_train_start=None, dataset_buffer=False, batch_size_initial=2, batch_size_max_iter=2000, batch_size_schedule=False, batch_size_max=64, grad_accum_initial=2, grad_accum_max_iter=2000, grad_accum_schedule=False, grad_accum_max=8, rope_freq_base=10000, rope_freq_scale=1.0, n_layer=20, n_head=16, head_size=128, num_kv_heads=None, n_embd=2048, intermediate_size=None, dropout=0.01, bias=False, multiple_of=256, norm_eps=1e-06, learning_rate=0.0003, num_train_epochs=None, max_iters=38000, weight_decay=0.1, beta1=0.9, beta2=0.95, grad_clip=1.0, decay_lr=True, warmup_iters=3800, lr_decay_iters=38000, lr_decay_reset_iters=3800, min_lr=0.0002, backend='nccl', device='cuda:0', dtype='float16', compile=True, compile_mode='default', mfu_flops_peak=-1.0, ignore_index=-100, pad_token_id=-1, weighted_loss=False, weighted_loss_method='allamo', adaptive_learning_rate=False, fsdp_sharding_strategy='FULL_SHARD', epoch_completion_hook_program=None, regular_checkpoint_hook_program=None, dpo_chosen_beta=0.5, dpo_rejected_beta=0.1, dpo_penalty_lambda=50.0, reference_checkpoint_name='ref_ckpt', training_type='pre', attention_implementation='sdpa', tensor_parallel_degree=1, prompt='\n', num_samples=1, max_new_tokens=50, temperature=0.8, top_k=100)
/opt/conda/lib/python3.8/site-packages/torch/_inductor/lowering.py:1778: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
2024-11-21 02:56:59,144 - root - INFO - iter 0: train loss=11.0163 ppl=60858.3906 acc=0.0000 (best loss=100.0000), val loss=10.9931 ppl=59462.6133 acc=0.0000 (best loss=100.0000), tokens 0
2024-11-21 02:58:18,635 - root - INFO - Epoch 1 finished
2024-11-21 02:58:36,173 - root - INFO - Epoch 2 finished
2024-11-21 02:58:53,681 - root - INFO - Epoch 3 finished
2024-11-21 02:59:11,465 - root - INFO - Epoch 4 finished
2024-11-21 02:59:11,740 - root - INFO - iter 0: loss 11.0166, ppl 60873.3672, acc 0.0000, iter time 225834.95ms, tokens 1,081,344, lr 0.00000000, mfu n/a, mtu 58.08%, epoch 4, ETA: N/A
2024-11-21 02:59:11,743 - root - INFO - saving model checkpoint to ../data/out-allamo-1B/model_epoch_0.pt
2024-11-21 02:59:16,395 - root - INFO - model checkpoint saved in ../data/out-allamo-1B/model_epoch_0.pt
2024-11-21 02:59:16,396 - root - INFO - saving config checkpoint to ../data/out-allamo-1B/config_epoch_0.json
2024-11-21 02:59:16,397 - root - INFO - checkpoint files saved in ../data/out-allamo-1B/
2024-11-21 02:59:34,108 - root - INFO - Epoch 5 finished
2024-11-21 02:59:51,808 - root - INFO - Epoch 6 finished
2024-11-21 03:00:09,532 - root - INFO - Epoch 7 finished
2024-11-21 03:00:27,363 - root - INFO - Epoch 8 finished
2024-11-21 03:00:27,562 - root - INFO - iter 1: loss 11.0166, ppl 60873.3672, acc 0.0000, iter time 71132.28ms, tokens 2,162,688, lr 0.00000008, mfu n/a, mtu 99.64%, epoch 8, ETA: 3189:23:05
2024-11-21 03:00:27,564 - root - INFO - saving model checkpoint to ../data/out-allamo-1B/model_epoch_4.pt
2024-11-21 03:00:32,184 - root - INFO - model checkpoint saved in ../data/out-allamo-1B/model_epoch_4.pt
2024-11-21 03:00:32,185 - root - INFO - saving config checkpoint to ../data/out-allamo-1B/config_epoch_4.json
2024-11-21 03:00:32,186 - root - INFO - checkpoint files saved in ../data/out-allamo-1B/
sankexin commented 6 days ago

Well done! Another way to get the original llama3 without changing the code is to modify the values with names similar to these in "allamo/train_configs/train_1B.json" (a rough sketch of the corresponding JSON keys follows the list below):

dropout: 0
dim: 4096
n_layers: 32
n_heads: 32
n_kv_heads: None
rope_theta: 500000
max_batch_size: 32
max_seq_len: 2048
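
For reference, this is roughly how those values might map onto the field names visible in the AllamoConfiguration dump above (n_embd for dim, n_layer, n_head, num_kv_heads, rope_freq_base for rope_theta, block_size for max_seq_len). The exact keys accepted by train_1B.json should be verified against the project, so treat this as a sketch rather than a confirmed config:

{
    "dropout": 0.0,
    "n_embd": 4096,
    "n_layer": 32,
    "n_head": 32,
    "num_kv_heads": null,
    "rope_freq_base": 500000,
    "block_size": 2048
}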