fkodom / dilated-attention-pytorch

(Unofficial) Implementation of dilated attention from "LongNet: Scaling Transformers to 1,000,000,000 Tokens" (https://arxiv.org/abs/2307.02486)

Training on yet-another-retnet script #4

Open Akbarable opened 1 year ago

Akbarable commented 1 year ago

Hello Frank!

I love what you have created, and am having a great time going through your implementation of the paper. It appears you have nailed the dilated attention calculation.

Here is my adaptation of the yet-another-retnet training script so far:

import os
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import tiktoken
import torch
from lightning import Fabric, seed_everything
from lightning.fabric.loggers import TensorBoardLogger
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

# from yet_another_retnet.retnet import RetNet  # left over from the original script
from yet_another_retnet.utils.gutenberg import project_gutenberg_top_100_datapipe

# NOTE: added imports; 'F' is needed for 'activation=F.relu' and the cross-entropy loss,
# and LongNetLM is assumed to live in this repo's 'long_net' module.
from torch.nn import functional as F
from dilated_attention_pytorch.long_net import LongNetLM

torch.set_float32_matmul_precision("medium")
TOKENIZER = tiktoken.get_encoding("gpt2")
EVAL_PROMPT = "A Lannister always pays his debts."

def collate_fn(
    batch: List[str],
    max_length: int = 4096,
    device: Optional[Union[torch.device, str]] = None,
) -> Tuple[Tensor, Tensor]:
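    # Build next-token-prediction pairs: 'x' is the input and 'y' is 'x' shifted by one
    # token. Sequences shorter than 'max_length' are left zero-padded.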
    x = torch.zeros(len(batch), max_length, device=device, dtype=torch.long)
    y = torch.zeros(len(batch), max_length, device=device, dtype=torch.long)

    for i, text in enumerate(batch):
        encoding = torch.as_tensor(
            TOKENIZER.encode(text), device=device, dtype=torch.long
        )
        seq_length = min(len(encoding) - 1, max_length)
        x[i, :seq_length] = encoding[:seq_length]
        y[i, :seq_length] = encoding[1 : seq_length + 1]

    return x, y

@dataclass
class TrainingState:
    fabric: Fabric
    model: LongNetLM
    optimizer: torch.optim.Optimizer
    callbacks: Sequence[Callable[["TrainingState", float], None]] = ()

    current_step: int = 0
    current_epoch: int = 0
    accumulate_grad_batches: int = 1
    monitor: str = "val_loss"
    monitor_mode: Literal["min", "max"] = "min"

@dataclass
class ModelCheckpoint:
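    # Serializable snapshot of the model/optimizer state plus progress counters.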
    state_dict: Dict[str, Tensor]
    optimizer_state: Dict[str, Tensor]
    current_step: int
    current_epoch: int

    @classmethod
    def from_training_state(cls, state: TrainingState) -> "ModelCheckpoint":
        return cls(
            state_dict=state.model.state_dict(),
            optimizer_state=state.optimizer.state_dict(),
            current_step=state.current_step,
            current_epoch=state.current_epoch,
        )

    def to_dict(self) -> Dict[str, Any]:
        return {
            "state_dict": self.state_dict,
            "optimizer_state": self.optimizer_state,
            "current_step": self.current_step,
            "current_epoch": self.current_epoch,
        }

    def save(self, path: str) -> None:
        torch.save(self.to_dict(), path)

    @classmethod
    def load(cls, path: str) -> "ModelCheckpoint":
        checkpoint_dict = torch.load(path)
        return cls(**checkpoint_dict)

class CheckpointCallback:
    def __init__(
        self, save_dir: str, name: str = "checkpoint_epoch-{epoch:03d}.pt"
    ) -> None:
        self.save_dir = save_dir
        self.name = name
        self.best_path: Optional[str] = None
        self.best_loss: Optional[float] = None

    def __call__(self, state: TrainingState, loss: float) -> None:
        if self.best_loss is None:
            self.best_loss = loss

        fabric = state.fabric
        # 'local_rank == 0' means this only happens for the main process
        if fabric.local_rank == 0 and loss <= self.best_loss:
            checkpoint = ModelCheckpoint.from_training_state(state)
            self.best_loss = loss
            if self.best_path is not None:
                os.remove(self.best_path)
            self.best_path = os.path.join(
                self.save_dir, self.name.format(epoch=state.current_epoch)
            )
            # Save as a plain dict so that ModelCheckpoint.load() can round-trip it.
            checkpoint.save(self.best_path)

        # All processes wait for main to finish saving the checkpoint.
        fabric.barrier()

def train_one_epoch(
    state: TrainingState,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    log_frequency: int = 25,
) -> None:
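    # One pass over the training data (with optional gradient accumulation), followed by
    # a validation pass; callbacks are invoked with the epoch's mean validation loss.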
    state.current_epoch += 1
    fabric, model, optimizer = state.fabric, state.model, state.optimizer
    is_training = model.training
    model.train()

    with tqdm(desc=f"Ep: {state.current_epoch}") as progbar:
        train_loss, val_loss = 0.0, 0.0
        for x, y in train_dataloader:
            state.current_step += 1
            accumulating = state.current_step % state.accumulate_grad_batches != 0
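            # Skip gradient synchronization on accumulation steps; gradients are only
            # synced and applied every 'accumulate_grad_batches' batches.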
            with fabric.no_backward_sync(model, enabled=accumulating):
                # LongNetLM.forward returns token logits rather than a loss, so compute
                # the next-token cross-entropy against the shifted targets 'y' here.
                logits = model.forward(x)
                loss = F.cross_entropy(logits.transpose(1, 2), y)
                fabric.backward(loss)

            if not accumulating:
                optimizer.step()
                optimizer.zero_grad()

            if state.current_step % log_frequency == 0:
                loss_scalar = loss.item()  # cross-entropy above already reduces to a scalar
                fabric.log("loss", loss_scalar, step=state.current_step)
                train_loss = loss_scalar
                progbar.set_postfix_str(f"loss={train_loss:.4f}", refresh=False)
            progbar.update(1)

        model.eval()
        val_progbar = tqdm(desc="val", position=1, leave=False)
        for i, (x, y) in enumerate(val_dataloader):
            with torch.inference_mode():
                logits = model.forward(x)
                loss = F.cross_entropy(logits.transpose(1, 2), y)
            # Running mean of the validation loss across batches.
            val_loss = (val_loss * i + loss.item()) / (i + 1)

            if i % log_frequency == 0:
                val_progbar.set_postfix_str(f"val_loss={val_loss:.4f}", refresh=False)
            val_progbar.update(1)
            progbar.update(1)

        fabric.log("val_loss", val_loss, step=state.current_step)
        val_progbar.close()
        progbar.set_postfix_str(
            f"loss={train_loss:.4f}, val_loss={val_loss:.4f}", refresh=False
        )

        for callback in state.callbacks:
            callback(state, val_loss)

        model.train(mode=is_training)

def train(
    longnet: LongNetLM,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    accelerator: str = "auto",
    strategy: str = "auto",
    precision: Optional[str] = None,
    epochs: int = 10,
    lr: float = 3e-4,
    log_frequency: int = 25,
):
    if precision is None:
        if torch.cuda.is_available():
            # bf16 mixed precision requires compute capability >= 8.0 (Ampere or newer)
            major, _ = torch.cuda.get_device_capability()
            precision = "bf16-mixed" if major >= 8 else "16-mixed"
        else:
            precision = "32-true"

    logger = TensorBoardLogger(root_dir="./")
    fabric = Fabric(
        accelerator=accelerator,
        strategy=strategy,
        precision=precision,  # type: ignore
        loggers=[logger],
    )
    fabric.launch()
    print(f"Experiment version: {logger.version}")
    print("-" * 40)

    # Setup with fabric.
    optimizer = torch.optim.AdamW(longnet.parameters(), lr=lr)
    longnet, optimizer = fabric.setup(longnet, optimizer)
    train_dataloader, val_dataloader = fabric.setup_dataloaders(
        train_dataloader, val_dataloader
    )
    # Construct a training state and run the training loop.
    state = TrainingState(
        fabric=fabric,
        model=longnet,
        optimizer=optimizer,
        callbacks=[CheckpointCallback(save_dir=logger.log_dir)],
    )
    for _ in range(epochs):
        train_one_epoch(
            state=state,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            log_frequency=log_frequency,
        )

def generate(
    longnet: LongNetLM,
    prompt: str,
    prompt_chunk_size: Optional[int] = None,
    max_new_tokens: int = 4096,
    stop_tokens: Sequence[str] = (),
    top_k: int = 10,
    temperature: float = 1.0,
    seed: int = 42,
) -> Iterator[str]:
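    # Feed the prompt through the model (optionally in chunks), then repeatedly sample
    # the next token with top-k / temperature sampling, yielding the full decoded text
    # after every new token.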
    seed_everything(seed)
    device = next(iter(longnet.parameters())).device
    is_training = longnet.training
    longnet.eval()

    # Tokenize the prompt and convert to a tensor.
    tokenized = TOKENIZER.encode(prompt)
    x = torch.as_tensor(tokenized, dtype=torch.long, device=device).unsqueeze_(0)

    if not prompt_chunk_size:
        prompt_chunk_size = x.size(1)

    prev_states: List[Optional[Tensor]] = [None] * longnet.num_layers
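    # NOTE: 'num_layers', 'prev_states', and 'start_idx' follow the RetNet
    # recurrent-forward API from yet-another-retnet; LongNetLM's forward likely does
    # not accept them, so this stateful decoding loop is an assumption carried over
    # from the original script.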
    start_idx: int = 0
    for start_idx in range(0, x.size(1), prompt_chunk_size):
        y, prev_states = longnet.forward(
            x, start_idx=start_idx, prev_states=prev_states
        )
        y = y[:, -1]

    # Generate tokens until we reach the maximum number of tokens or a stop token.
    for i in range(max_new_tokens):
        probs: Tensor = torch.softmax(y.squeeze() / max(temperature, 1e-8), dim=-1)
        # Get top-k tokens, renormalize their probabilities, and weighted sample.
        tokens: Tensor  # for mypy
        probs, tokens = probs.topk(k=top_k, dim=-1)
        probs /= probs.sum()

        # Take weighted random sample from the top-k tokens.
        sampled_idx: int = torch.multinomial(probs, num_samples=1).item()  # type: ignore
        token: int = tokens[sampled_idx].item()  # type: ignore
        tokenized.append(token)
        yield TOKENIZER.decode(tokenized)

        token_str: str = TOKENIZER.decode([token])
        if token_str in stop_tokens:
            break
        elif i < (max_new_tokens - 1):
            start_idx += 1
            x = torch.as_tensor([token], dtype=torch.long, device=device)
            y, prev_states = longnet.forward(
                x, start_idx, prev_states=prev_states
            )

    # Restore the model's original training state.
    longnet.train(mode=is_training)

def main(
    model_checkpoint: Optional[str] = None,
    accelerator: str = "auto",
    strategy: str = "auto",
    precision: Optional[str] = None,
    epochs: int = 10,
    batch_size: int = 16,
    lr: float = 3e-4,
    log_frequency: int = 25,
    seed: int = 42,
    eval_only: bool = False,
    eval_prompt: str = EVAL_PROMPT,
    eval_max_tokens: int = 1024,
):
    seed_everything(seed)
    # Create a (relatively small) model and dataloaders
    longnet = LongNetLM(
        num_tokens=TOKENIZER.n_vocab,
        d_model=768,
        nhead=12,
        num_encoder_layers=12,
        num_decoder_layers=12,
        dim_feedforward=3072,
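        # Each segment length is paired with the dilation rate at the same index.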
        segment_lengths=[512, 1024, 2048, 4096],
        dilation_rates=[1, 2, 4, 6],
        dropout=0.1,
        activation=F.relu,
        layer_norm_eps=1e-5,
    )
    if model_checkpoint is not None:
        longnet.load_state_dict(ModelCheckpoint.load(model_checkpoint).state_dict)

    if not eval_only:
        train_dataloader = DataLoader(
            project_gutenberg_top_100_datapipe(
                split="train",
                chunk_size=4096,
                step_size=1024,
                shuffle=True,
                drop_last=True,
            ),
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=True,
        )
        val_dataloader = DataLoader(
            project_gutenberg_top_100_datapipe(
                split="val", chunk_size=4096, step_size=1024
            ),
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

        train(
            longnet=longnet,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            accelerator=accelerator,
            strategy=strategy,
            precision=precision,
            epochs=epochs,
            lr=lr,
            log_frequency=log_frequency,
        )

    # Generate some text
    prev_output: str = ""
    for output in generate(longnet, eval_prompt, max_new_tokens=eval_max_tokens):
        # Return to the start of the line and print the output (no newline)
        print(output[len(prev_output) :], end="", flush=True)
        prev_output = output
    print()

# Default values (edit these as needed)
default_model_checkpoint = None
default_accelerator = "auto"
default_strategy = "dp"
default_precision = None
default_epochs = 1
default_batch_size = 1
default_lr = 3e-4
default_log_frequency = 25
default_seed = 42
default_eval_only = False
default_eval_prompt = EVAL_PROMPT
default_eval_max_tokens = 1024

# Plain variables in place of the original argparse CLI
model_checkpoint = default_model_checkpoint
accelerator = default_accelerator
strategy = default_strategy
precision = default_precision
epochs = default_epochs
batch_size = default_batch_size
lr = default_lr
log_frequency = default_log_frequency
seed = default_seed
eval_only = default_eval_only
eval_prompt = default_eval_prompt
eval_max_tokens = default_eval_max_tokens

# Call the main function
main(
    model_checkpoint=model_checkpoint,
    accelerator=accelerator,
    strategy=strategy,
    precision=precision,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    log_frequency=log_frequency,
    seed=seed,
    eval_only=eval_only,
    eval_prompt=eval_prompt,
    eval_max_tokens=eval_max_tokens
)
fkodom commented 1 year ago

@Akbarable Letting you know I see this issue. πŸ‘€ Having a very busy week, so it may be a few days before I can dedicate time to this. Will update you as soon as I am able to.

Akbarable commented 1 year ago

> @Akbarable Letting you know I see this issue. πŸ‘€ Having a very busy week, so it may be a few days before I can dedicate time to this. Will update you as soon as I am able to.

Thanks for the response. No worries! 😁

andrewcchu commented 1 year ago

Bumping this as I'd also be interested in a training script -- tried the code @Akbarable posted here and saw the same behavior.