Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

Weird error when using activation checkpointing for FSDPStrategy #805

Open · RogerChern opened 11 months ago

RogerChern commented 11 months ago

I'm training tinyllama on 8 A40s. Everything goes very smoothly until I try to increase the micro batch size for a better computation-to-communication ratio.

Following the official lit-gpt tutorial, I pass activation_checkpointing_policy={Block} to FSDPStrategy. The modified setup is attached below.

# imports assumed to match the stock pretrain/tinyllama.py script
from pathlib import Path
from typing import Optional, Union

from lightning.fabric.strategies import FSDPStrategy

from lit_gpt.model import Block
from lit_gpt.utils import get_default_supported_precision


def setup(
    devices: int = 8,
    train_data_dir: Path = Path("data/redpajama_sample"),
    val_data_dir: Optional[Path] = None,
    precision: Optional[str] = None,
    tpu: bool = False,
    resume: Union[bool, Path] = False,
) -> None:
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)

    if devices > 1:
        if tpu:
            ...  # TPU branch unchanged (elided)
        else:
            strategy = FSDPStrategy(
                # wrap each transformer Block as its own FSDP unit
                auto_wrap_policy={Block},
                # recompute each Block's activations during backward
                activation_checkpointing_policy={Block},
                state_dict_type="full",
                limit_all_gathers=True,
                cpu_offload=False,
                sharding_strategy="FULL_SHARD",
            )
    else:
        strategy = "auto"

But I get a strange error related to activation checkpointing. Could someone shed some light on this? Anything informative would be a big help.

Traceback (most recent call last):
  File "pretrain/tinyllama.py", line 424, in <module>
    CLI(setup)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/usr/local/lib/python3.8/dist-packages/jsonargparse/_cli.py", line 181, in _run_component
    return component(**cfg)
  File "pretrain/tinyllama.py", line 108, in setup
    main(fabric, train_data_dir, val_data_dir, resume)
  File "pretrain/tinyllama.py", line 160, in main
    train(fabric, state, train_dataloader, val_dataloader, monitor, resume)
  File "pretrain/tinyllama.py", line 244, in train
    fabric.backward(loss / gradient_accumulation_steps)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/fabric.py", line 422, in backward
    self._strategy.backward(tensor, module, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/strategies/strategy.py", line 192, in backward
    self.precision.backward(tensor, module, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fsdp.py", line 126, in backward
    super().backward(tensor, model, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/precision.py", line 107, in backward
    tensor.backward(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 1075, in unpack_hook
    frame.check_recomputed_tensors_match(gid)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 812, in check_recomputed_tensors_match
    raise CheckpointError(
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 27
Number of tensors saved during recomputation: 8
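
For reference, this CheckpointError comes from PyTorch's non-reentrant activation checkpointing, which re-runs the checkpointed forward during backward and verifies that the recomputation saves the same tensors as the original forward. A minimal standalone sketch, unrelated to lit-gpt and purely illustrative, of a checkpointed function whose recomputation diverges and should trip the same check:

import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}

def fn(x):
    # Non-deterministic region: the recomputation builds a different graph
    # than the original forward, so a different number of tensors is saved.
    calls["n"] += 1
    if calls["n"] == 1:
        return (x.sin() * x.cos()).sum()
    return x.sum()

x = torch.randn(4, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False)
out.backward()  # should raise torch.utils.checkpoint.CheckpointError

In the issue above, the same class of mismatch means something inside the checkpointed Block behaves differently between the original forward and the recomputation.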
carmocca commented 11 months ago

Hi! The setup you shared in your first snippet is very different from the setup in https://github.com/Lightning-AI/lit-gpt/blob/main/pretrain/tinyllama.py#L66. Can you share all the changes you made to the repo? You can do:

git diff > changes.diff

And then post the changes.diff file here.
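
Anyone trying to reproduce the modified setup can then apply the posted diff on a clean checkout, e.g.:

git apply changes.diff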

cc @awaelchli in case you are familiar with this