Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

Multi-GPU training #328

Closed by srikanthsrnvs 1 year ago

srikanthsrnvs commented 1 year ago

I have three questions about the current codebase.

  1. How does the chunked cross-entropy loss work? How do we scale the value it returns to match the ordinary cross-entropy loss?
  2. When using the finetune/lora.py code to train on 8 GPUs, I see that the validate() function is called 8 times, once in each of 8 separate processes. Why is that? Why does Fabric run validation and training on each process separately? As far as I understand, with DDP the model is cloned across devices and a single shared device acts as the manager for the gradient updates. Or am I incorrect? If so, why do all 8 processes produce training and validation output? Shouldn't there be only one?
  3. The last question is about an OOM. There is currently a bug where lora.py runs out of memory with a batch size of 1 for a 13B-parameter Llama model and a context length of 4096. Training works fine until it hits a validation interval, at which point the OOM is triggered. Here is my current lora.py:
def setup(
    data: Path,
    size: str = "13b",
    out_dir: Path = Path("/scratch/data/models/finetuned"),
    devices: int = torch.cuda.device_count(),
    precision: Optional[str] = None,
    tpu: bool = False,
    context_length: int = 4096,
    log_interval: int = 10
):
    checkpoint_dir: Path = Path(f"/scratch/data/meta-llama/Llama-2-{size}-hf/")

    if precision is None:
        precision = "32-true" if tpu else "bf16-mixed"
    fabric_devices = devices
    if fabric_devices > 1:
        if tpu:
            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
            fabric_devices = "auto"
            strategy = XLAStrategy(sync_module_states=False)
        else:
            precision="bf16-true"
            strategy=FSDPStrategy(
                auto_wrap_policy={Block},
                activation_checkpointing_policy={Block},
                state_dict_type="full",
                limit_all_gathers=True,
            )
    else:
        strategy = "auto"

    print("Launching Fabric...")
    logger = step_csv_logger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
    fabric = L.Fabric(devices=fabric_devices, precision=precision, loggers=logger)
    fabric.print(hparams)
    main(fabric, data, checkpoint_dir, out_dir, context_length)

def main(
        fabric: L.Fabric,
        data: Path,
        checkpoint_dir: Path, 
        out_dir: Path,
        context_length: int,
        learning_rate: float = 3e-5,
        weight_decay: float = 0.01,
        batch_size: int = 1,
        epochs: int = 3,
        warmup: int = 20,
        gradient_accumulation_steps: int = 8,
        log_interval: int = 250,
        eval_interval: int = 50,
        save_interval: int = 50,
    ):
    check_valid_checkpoint_dir(checkpoint_dir)

    # Load the checkpoint directory and make the dataframes
    data_dir = Path("/scratch/data/datasets/supervised")
    # if fabric.local_rank == 0:
    #     print("Preprocessing data...")
    #     data_frame = pd.read_csv(data)
    #     data_dir = prepare_supervised(data_frame, checkpoint_dir, context_length=context_length)

    fabric.barrier()

    speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")

    # fabric.seed_everything(1337)  # same seed for every process to init model (FSDP)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    train_data = torch.load(data_dir / "train.pt")
    val_data = torch.load(data_dir / "test.pt")

    LORA_HP = {
        "r": 8,
        "alpha": 16,
        "dropout": 0.05,
        "query": True,
        "key": False,
        "value": True,
        "projection": False,
        "mlp": False,
        "head": False,
    }

    if not any((LORA_HP["query"], LORA_HP["key"], LORA_HP["value"], LORA_HP["projection"], LORA_HP["mlp"], LORA_HP["head"])):
        fabric.print("Warning: all LoRA layers are disabled!")
    config = Config.from_name(
        name=checkpoint_dir.name,
        r=LORA_HP["r"],
        alpha=LORA_HP["alpha"],
        dropout=LORA_HP["dropout"],
        to_query=LORA_HP["query"],
        to_key=LORA_HP["key"],
        to_value=LORA_HP["value"],
        to_projection=LORA_HP['projection'],
        to_mlp=LORA_HP['mlp'],
        to_head=LORA_HP["head"],
    )
    checkpoint_path = checkpoint_dir / "lit_model.pth"
    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")

    with fabric.init_module(empty_init=True):
        model = GPT(config)
        model.apply(model._init_weights)  # for the LoRA weights

    print("Loading state dict...")
    fabric.load_raw(checkpoint_path, model, strict=False)

    # with lazy_load(checkpoint_path) as checkpoint:
        # strict=False because missing keys due to LoRA weights not contained in state dict
        # model.load_state_dict(checkpoint, strict=False)

    mark_only_lora_as_trainable(model)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in trainable_params)
    fabric.print(f"Number of trainable parameters: {num_params:,}")
    num_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    fabric.print(f"Number of non trainable parameters: {num_params:,}")

    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
    model, optimizer = fabric.setup(model, optimizer)

    # fabric.seed_everything(1337 + fabric.global_rank)

    print("Training setup...")
    train_time = time.time()
    train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir, speed_monitor, batch_size, epochs, warmup, learning_rate, gradient_accumulation_steps, log_interval, eval_interval, save_interval)
    fabric.print(f"Training time: {(time.time()-train_time):.2f}s")

    # Save the final LoRA checkpoint at the end of training
    save_path = out_dir / "finetuned.pth"
    save_lora_checkpoint(fabric, model, save_path)

def train(
    fabric: L.Fabric,
    model: GPT,
    optimizer: torch.optim.Optimizer,
    train_data: List[Dict],
    val_data: List[Dict],
    checkpoint_dir: Path,
    out_dir: Path,
    speed_monitor: SpeedMonitor,
    batch_size: int,
    epochs: int,
    warmup_steps: int,
    learning_rate: float,
    gradient_accumulation_steps: int,
    log_interval: int,
    eval_interval: int,
    save_interval: int

) -> None:
    tokenizer = Tokenizer(checkpoint_dir)
    max_seq_length, longest_seq_length, longest_seq_ix = get_max_seq_length(train_data)
    fabric.barrier()

    model.eval()
    val_loss = validate(fabric, model, val_data, tokenizer, longest_seq_length, batch_size=64)  # sanity check

    fabric.print(f"Val loss {val_loss:.4f}")

    with torch.device("meta"):
        meta_model = GPT(model.config)
        # estimated is too much of an optimistic estimate, left just for reference
        estimated_flops = estimate_flops(meta_model) * batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        x = torch.randint(0, 1, (batch_size, model.config.block_size))
        measured_flops = measure_flops(meta_model, x)
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x

    step_count = 0
    total_lengths = 0
    iterations = epochs * len(train_data) // batch_size
    total_t0 = time.time()

    if fabric.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()

    if fabric.global_rank == 0:
        wandb.init(project="dreamshow-lora-torch", name="0.0.3")
        wandb.watch(model)

    fabric.barrier()
    model.train()

    for iter_num in tqdm(range(iterations), desc="Begun training"):
        if step_count <= warmup_steps:
            # linear warmup
            lr = learning_rate * step_count / warmup_steps
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        iter_t0 = time.time()

        input_ids, targets = get_batch(
            fabric, train_data, longest_seq_length, batch_size, longest_seq_ix if iter_num == 0 else None
        )

        is_accumulating = (iter_num + 1) % gradient_accumulation_steps != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids, max_seq_length=max_seq_length, lm_head_chunk_size=0)
            # shift the targets such that output n predicts token n+1
            # logits[-1] = logits[-1][..., :-1, :]
            loss = loss_fn(logits, targets)
            fabric.backward(loss / gradient_accumulation_steps)

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1
        elif fabric.device.type == "xla":
            xm.mark_step()

        t1 = time.time()
        total_lengths += input_ids.size(1)
        speed_monitor.on_train_batch_end(
            (iter_num + 1) * batch_size,
            t1 - total_t0,
            # this assumes that device FLOPs are the same and that all devices have the same batch size
            fabric.world_size,
            flops_per_batch=measured_flops,
            lengths=total_lengths,
        )
        if iter_num % log_interval == 0:
            fabric.print(
                f"iter {iter_num} step {step_count}: loss {loss.item():.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            )
            if fabric.global_rank == 0:
                wandb.log({"train_loss": loss.item(), "train_step": step_count})

        if not is_accumulating and step_count % eval_interval == 0:
            t0 = time.time()
            val_loss = validate(fabric, model, val_data, tokenizer, longest_seq_length, batch_size=64)
            t1 = time.time() - t0
            speed_monitor.eval_end(t1)
            fabric.print(f"step {iter_num}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
            fabric.barrier()
            if fabric.global_rank == 0:
                # log the validation loss here, not the last training loss
                wandb.log({"eval_loss": val_loss, "eval_step": step_count})

        if not is_accumulating and step_count % save_interval == 0:
            checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
            save_lora_checkpoint(fabric, model, checkpoint_path)

@torch.no_grad()
def validate(
    fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, longest_seq_length: int, batch_size: int = 1
) -> torch.Tensor:
    model.eval()

    eval_iterations = len(val_data) // batch_size
    losses = torch.zeros(eval_iterations)
    for k in tqdm(range(eval_iterations), desc="Validating"):
        input_ids, targets = get_batch(fabric, val_data, longest_seq_length, batch_size)
        logits = model(input_ids)
        loss = loss_fn(logits, targets)
        losses[k] = loss.item()
    val_loss = losses.mean()

    model.reset_cache()

    model.train()
    return val_loss.item()

def get_batch(
    fabric: L.Fabric, data: List[Dict], longest_seq_length: int,  batch_size: int, longest_seq_ix: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    ix = torch.randint(len(data), (batch_size,))
    if longest_seq_ix is not None:
        # force the longest sample at the beginning so potential OOMs happen right away
        ix[0] = longest_seq_ix

    input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
    labels = [data[i]["labels"].type(torch.int64) for i in ix]

    # it's better to pad to a fixed seq length with XLA to avoid recompilation
    max_len = max(len(s) for s in input_ids) if fabric.device.type != "xla" else longest_seq_length

    def pad_right(x, pad_id):
        # pad right based on the longest sequence
        n = max_len - len(x)
        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
    y = torch.stack([pad_right(x, pad_id=-1) for x in labels])

    if fabric.device.type == "cuda" and x.device.type == "cpu":
        x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    else:
        x, y = fabric.to_device((x, y))
    return x, y

def get_max_seq_length(data: List[Dict]) -> Tuple[int, int, int]:
    # find out the minimum max_seq_length required during fine-tuning (saves memory!)
    lengths = [len(d["input_ids"]) for d in data]
    max_seq_length = max(lengths)
    longest_seq_ix = lengths.index(max_seq_length)
    # support easy override at the top of the file
    return (
        override_max_seq_length if isinstance(override_max_seq_length, int) else max_seq_length,
        max_seq_length,
        longest_seq_ix,
    )

awaelchli commented 1 year ago

Hi @srikanthsrnvs

Thanks for your interest in the library and your questions.

  1. The chunking in the cross-entropy loss is quite simple: instead of computing the cross entropy over the entire batch and sequence at once, the computation is done on smaller chunks (splitting along the batch dimension) to avoid memory peaks, at the cost of slightly slower speed. The returned loss should be the same as in the non-chunked version; the user does not need to do any scaling.

  2. The behavior you describe is correct and expected. In multi-GPU training with N GPUs, there are N processes running in parallel, each assigned to one GPU. There is no single shared "manager" device. Perhaps you are mixing it up with some other form of training, but this is not how it works in Lightning / PyTorch.

  3. I'll run this and let you know what I find, but unfortunately it has to wait until Monday. I will get back to you then. Based on your description alone, I suspect that the sequence length is the big factor pushing you to OOM. Our finetuning scripts and the numbers we documented in the READMEs were for much smaller sequence lengths (because the finetuning data had shorter prompts).
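
To put rough numbers on the sequence-length suspicion in point 3: the posted script trains with batch_size=1 but calls validate() with batch_size=64. The sketch below estimates the size of the logits tensor alone. It assumes a vocabulary of 32,000 tokens (Llama 2), 2 bytes per element (bf16), and batches padded all the way up to the 4096 context length. Real usage depends on the actual sequence lengths and on every other activation, so treat this as an upper-bound illustration rather than a diagnosis. The helper name logits_bytes is hypothetical.

```python
def logits_bytes(batch_size: int, seq_len: int, vocab_size: int, bytes_per_elem: int) -> int:
    """Size in bytes of a (batch, seq, vocab) logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_elem


GIB = 1024 ** 3
VOCAB = 32_000  # assumption: Llama 2 vocabulary size
SEQ = 4096      # context_length from the posted script; assumes full-length padding
BF16 = 2        # assumption: logits stored in bf16 (2 bytes per element)

train_gib = logits_bytes(1, SEQ, VOCAB, BF16) / GIB   # training batch size in the script
val_gib = logits_bytes(64, SEQ, VOCAB, BF16) / GIB    # validation batch size in the script

print(f"logits at batch size 1:  {train_gib:.2f} GiB")
print(f"logits at batch size 64: {val_gib:.2f} GiB")
```

Under these assumptions the validation logits are 64 times larger than the training logits, which would be consistent with training running fine and the OOM appearing only at a validation interval.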

srikanthsrnvs commented 1 year ago

Gotcha, thanks.

  1. Based on your response, the chunked loss should equal the loss computed without chunking. I tried this and got wildly different loss values: my validation loss is somewhere in the 14-15 range and my train loss in the 0.01-0.05 range, which makes no sense to me. I am probably using it wrong, but I'm not sure how.
  2. If I wanted a naive Weights & Biases integration, should I just do an if fabric.global_rank == 0: wandb.log()? I don't want to reinitialize wandb on every process.
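
One way to sanity-check the equivalence claim from point 1 is to compare a chunked and a non-chunked cross entropy on the same inputs. The following is a minimal pure-Python sketch, not the library's implementation: all function names are hypothetical, the positions are treated as one flattened list, and padded labels are assumed to be marked with -1 (the pad_id used for labels in get_batch above). It also shows one possible way to get a different number by accident: averaging per-chunk means when chunks contain different numbers of non-ignored targets.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -1  # assumption: padded label positions are marked with -1


def token_nll(logits: Sequence[float], target: int) -> float:
    """Negative log-likelihood of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def cross_entropy(logits: List[Sequence[float]], targets: List[int]) -> float:
    """Mean NLL over all non-ignored positions, computed in one pass."""
    losses = [token_nll(l, t) for l, t in zip(logits, targets) if t != IGNORE_INDEX]
    return sum(losses) / max(len(losses), 1)


def chunked_cross_entropy(logits, targets, chunk_size: int = 2) -> float:
    """Same value as cross_entropy, but only one chunk is processed at a time."""
    total, count = 0.0, 0
    for start in range(0, len(targets), chunk_size):
        chunk = zip(logits[start:start + chunk_size], targets[start:start + chunk_size])
        for l, t in chunk:
            if t != IGNORE_INDEX:
                total += token_nll(l, t)
                count += 1
    # divide once by the global number of valid targets
    return total / max(count, 1)


def mean_of_chunk_means(logits, targets, chunk_size: int = 2) -> float:
    """Pitfall: weights every chunk equally, regardless of its valid-target count."""
    starts = range(0, len(targets), chunk_size)
    means = [cross_entropy(logits[s:s + chunk_size], targets[s:s + chunk_size]) for s in starts]
    return sum(means) / len(means)


# five flattened positions, three classes, one padded label
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.5, -0.5, 0.0], [0.0, 0.0, 3.0], [-1.0, 2.0, 0.5]]
targets = [0, 2, IGNORE_INDEX, 2, 1]

full = cross_entropy(logits, targets)
chunked = chunked_cross_entropy(logits, targets)
naive = mean_of_chunk_means(logits, targets)
print(f"full={full:.6f} chunked={chunked:.6f} mean-of-chunk-means={naive:.6f}")
```

Here the chunked and full values agree because the sum is divided once by the total number of valid targets, while the mean of chunk means does not. Whether this has anything to do with the gap between the reported train and validation losses is not established by the thread.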

awaelchli commented 1 year ago

@srikanthsrnvs

1) Could you show the inputs for which the function returns wildly different values? I extended the test case in #343 to show that the chunked implementation is equivalent to the regular cross-entropy loss.

2) Yes :)
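
The rank-0 guard confirmed in point 2 can be wrapped in a small helper so the check is not repeated at every call site. This is a minimal sketch under stated assumptions: the helper only relies on the object it receives having a global_rank attribute (as fabric.global_rank is used in the posted script), RankZeroLogger and FakeFabric are hypothetical names for illustration, and log_fn stands in for wandb.log.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class FakeFabric:
    """Hypothetical stand-in that exposes only the attribute the helper reads."""
    global_rank: int


class RankZeroLogger:
    """Forwards metrics to `log_fn` on global rank 0 and does nothing elsewhere."""

    def __init__(self, fabric: Any, log_fn: Callable[[Dict[str, float]], None]) -> None:
        self.enabled = fabric.global_rank == 0
        self.log_fn = log_fn

    def log(self, metrics: Dict[str, float]) -> None:
        if self.enabled:
            self.log_fn(metrics)


# simulate four processes; only the rank-0 call reaches the sink
received: List[Dict[str, float]] = []
for rank in range(4):
    logger = RankZeroLogger(FakeFabric(global_rank=rank), received.append)
    logger.log({"train_loss": 0.5, "rank": rank})

print(received)
```

Note that with this pattern only the value computed on rank 0 is logged. Every process still computes its own loss on its own batch, so the values would have to be reduced across processes first if an average over all GPUs is wanted.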