Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

autocast to float16/bfloat16 fails on transformer encoder #19980

Status: Open. AmitMY opened this issue 3 weeks ago.

AmitMY commented 3 weeks ago

Bug description

Using bf16 mixed precision in the Trainer raises a dtype-mismatch RuntimeError.

What version are you seeing the problem on?

v2.3

How to reproduce the bug

My model includes this encoder:

        self.encoder = nn.Sequential(
            nn.Flatten(start_dim=2),
            nn.Dropout(0.15),
            nn.Linear(math.prod(pose_dims), hidden_dim, bias=False),
            PositionalEncoding(d_model=hidden_dim),
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
                                           dim_feedforward=dim_feedforward,
                                           batch_first=True),
                num_layers=num_layers
            )
        )

Then run the Trainer with precision="bf16-mixed".

(Note: "bf16-true" works, but gives a very poor learning curve.)
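A rough plain-PyTorch approximation of why "bf16-true" does not hit the error: instead of autocasting individual ops, true half precision casts the module's parameters (and, roughly speaking, the floating-point inputs) to bfloat16, so the weights and activations reaching the transformer already share one dtype. This is a sketch, not Lightning's implementation; the helper make_encoder and the reduced layer sizes are hypothetical.

```python
import math

import torch
from torch import nn


def make_encoder(pose_dims=(178, 3), hidden_dim=64, nhead=4, dim_feedforward=128, num_layers=2):
    # Same layer types as the encoder above (minus the user's PositionalEncoding),
    # with smaller, hypothetical sizes so the example runs quickly.
    return nn.Sequential(
        nn.Flatten(start_dim=2),
        nn.Linear(math.prod(pose_dims), hidden_dim, bias=False),
        nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
                                       dim_feedforward=dim_feedforward, batch_first=True),
            num_layers=num_layers,
        ),
    )


model = make_encoder().to(torch.bfloat16)  # every parameter is now bf16 ("true" precision)
model.eval()  # eval mode, as in the validation sanity check that crashed

batch = torch.randn(2, 10, 178, 3).to(torch.bfloat16)  # inputs in the same dtype as the weights
with torch.no_grad():
    out = model(batch)  # no autocast: weights and activations already match

print(out.dtype, tuple(out.shape))
```

The trade-off is that all weights and optimizer updates also live in bfloat16, which may relate to the worse learning curve mentioned above.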

Error messages and logs

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/train.py", line 147, in <module>
    main()
  File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/train.py", line 143, in main
    trainer.fit(model, train_dataloaders=train_dataset, val_dataloaders=validation_dataset)
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1028, in _run_stage
    self._run_sanity_check()
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1057, in _run_sanity_check
    val_loop.run()
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 411, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/model.py", line 234, in validation_step
    loss, prediction = self.step(batch)
                       ^^^^^^^^^^^^^^^^
  File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/model.py", line 215, in step
    x_hat, indices = self(x)
                     ^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/model.py", line 170, in forward
    return self.model(batch)
           ^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/amitmoryossef/dev/sign-language-processing/vq/sign_vq/model.py", line 129, in forward
    x = self.encoder(x)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 391, in forward
    output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/envs/vq/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 685, in forward
    return torch._transformer_encoder_layer_fwd(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float

Environment

Current environment
* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning-utilities: 0.11.2
  - pytorch-lightning: 2.3.0
  - torch: 2.2.2
  - torchmetrics: 1.4.0.post0
  - vector-quantize-pytorch: 1.14.24
* Packages:
  - aiohttp: 3.9.5
  - aiosignal: 1.3.1
  - astroid: 3.2.2
  - attrs: 23.2.0
  - certifi: 2024.6.2
  - charset-normalizer: 3.3.2
  - click: 8.1.7
  - datasets: 2.20.0
  - decorator: 4.4.2
  - dill: 0.3.8
  - docker-pycreds: 0.4.0
  - einops: 0.8.0
  - einx: 0.3.0
  - filelock: 3.15.1
  - frozendict: 2.4.4
  - frozenlist: 1.4.1
  - fsspec: 2024.5.0
  - gitdb: 4.0.11
  - gitpython: 3.1.43
  - huggingface-hub: 0.23.3
  - idna: 3.7
  - imageio: 2.34.1
  - imageio-ffmpeg: 0.5.1
  - iniconfig: 2.0.0
  - isort: 5.13.2
  - jinja2: 3.1.4
  - lightning-utilities: 0.11.2
  - markupsafe: 2.1.5
  - mccabe: 0.7.0
  - moviepy: 1.0.3
  - mpmath: 1.3.0
  - multidict: 6.0.5
  - multiprocess: 0.70.16
  - networkx: 3.3
  - numpy: 1.26.4
  - opencv-python: 4.10.0.82
  - packaging: 24.1
  - pandas: 2.2.2
  - pillow: 10.3.0
  - pip: 24.0
  - platformdirs: 4.2.2
  - pluggy: 1.5.0
  - pose-format: 0.4.1
  - proglog: 0.1.10
  - protobuf: 5.27.1
  - psutil: 5.9.8
  - pyarrow: 16.1.0
  - pyarrow-hotfix: 0.6
  - pylint: 3.2.3
  - pytest: 8.2.2
  - python-dateutil: 2.9.0.post0
  - pytorch-lightning: 2.3.0
  - pytz: 2024.1
  - pyyaml: 6.0.1
  - requests: 2.32.3
  - scipy: 1.13.1
  - sentry-sdk: 2.5.1
  - setproctitle: 1.3.3
  - setuptools: 69.5.1
  - sign-vq: 0.0.1
  - six: 1.16.0
  - smmap: 5.0.1
  - sympy: 1.12.1
  - tomlkit: 0.12.5
  - torch: 2.2.2
  - torchmetrics: 1.4.0.post0
  - tqdm: 4.66.4
  - typing-extensions: 4.12.2
  - tzdata: 2024.1
  - urllib3: 2.2.1
  - vector-quantize-pytorch: 1.14.24
  - wandb: 0.17.1
  - wheel: 0.43.0
  - xxhash: 3.4.1
  - yarl: 1.9.4
* System:
  - OS: Darwin
  - architecture: 64bit
  - processor: i386
  - python: 3.11.9
  - release: 23.5.0
  - version: Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000

More info

I tried following https://github.com/Lightning-AI/pytorch-lightning/issues/15006 and feeding the batch directly as bf16, but that does not change the error.

awaelchli commented 3 weeks ago

@AmitMY The Trainer applies the PyTorch autocast context manager around the forward pass and converts the inputs. Look at the error traceback, in particular this line:

    output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)

and from there, work out which tensors (output, or the weights of the TransformerEncoder) have mismatched dtypes. It's possible that the input tensor here is the output of the previous layer (e.g. PositionalEncoding), in which case the dtype mismatch needs to be fixed there.
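One way to do that check, sketched in plain PyTorch: register forward pre-hooks that record, for every submodule call, the dtypes of its positional tensor inputs and of its parameters, then run one eval-mode forward pass under CPU bf16 autocast. The helper name trace_dtypes and the reduced layer sizes are hypothetical. Pre-hooks fire before each module runs, so the record for the module that raises is still captured.

```python
import math

import torch
from torch import nn


def trace_dtypes(model: nn.Module, batch: torch.Tensor):
    """Record (module class, input dtypes, parameter dtypes) for each submodule call.

    Returns the records plus the RuntimeError raised by the forward pass, if any.
    """
    records = []

    def pre_hook(module, args):
        in_dtypes = [a.dtype for a in args if isinstance(a, torch.Tensor)]
        # Recursive, so a TransformerEncoderLayer reports its attention/feed-forward weights.
        param_dtypes = {p.dtype for p in module.parameters()}
        records.append((type(module).__name__, in_dtypes, param_dtypes))

    handles = [m.register_forward_pre_hook(pre_hook) for m in model.modules()]
    error = None
    try:
        model(batch)
    except RuntimeError as err:  # e.g. "mat1 and mat2 must have the same dtype"
        error = err
    finally:
        for handle in handles:
            handle.remove()
    return records, error


# Hypothetical, smaller version of the encoder from the report (no PositionalEncoding).
model = nn.Sequential(
    nn.Flatten(start_dim=2),
    nn.Linear(math.prod((178, 3)), 64, bias=False),
    nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=128, batch_first=True),
        num_layers=2,
    ),
)
model.eval()

batch = torch.randn(2, 10, 178, 3)
with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
    records, error = trace_dtypes(model, batch)

for name, in_dtypes, param_dtypes in records:
    print(f"{name:24s} inputs={in_dtypes} params={sorted(str(d) for d in param_dtypes)}")
print("error:", error)
```

A record showing bfloat16 inputs next to float32 parameters points at the module where the mismatch first meets a kernel that does not go through autocast.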

If there is reason to believe something is not done right in Lightning, please provide a reproducible example. Thanks!

AmitMY commented 3 weeks ago

I believe this is a problem with pytorch-lightning because plain torch autocasting works fine:

    def test_training_step_bfloat16_expected_loss_finite(self):
        batch = MaskedTensor(torch.full((4, 3, *self.pose_dim), fill_value=2, dtype=torch.float))
        model = self.model_setup()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            loss = model.training_step(batch)
        self.assertNotEqual(0, float(loss))
        self.assertTrue(torch.isfinite(loss))

As for the input to the transformer: with both plain torch autocast and Lightning, I see:

    dtype in PositionalEncoding torch.bfloat16
    dtype out PositionalEncoding torch.float32
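The PositionalEncoding code isn't shown in the thread, but a bf16-in, float32-out pattern is what ordinary type promotion produces if the layer adds a float32 buffer to its input: addition is not one of the ops CPU autocast casts to bfloat16, so bfloat16 + float32 promotes to float32. Below is a sketch with a hypothetical sinusoidal implementation (assumed, not the author's layer) and an optional cast of the buffer to the input dtype. Since removing the layer still crashes, this promotion is a side effect rather than the cause of the error.

```python
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Hypothetical sinusoidal positional encoding; the real layer is not shown in the thread."""

    def __init__(self, d_model: int, max_len: int = 512, match_input_dtype: bool = False):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)  # float32 buffer; autocast does not change it
        self.match_input_dtype = match_input_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, seq, d_model)
        pe = self.pe[: x.size(1)]
        if self.match_input_dtype:
            pe = pe.to(x.dtype)  # keep a bf16 activation in bf16 instead of promoting it
        return x + pe


x = torch.randn(2, 10, 8, dtype=torch.bfloat16)  # e.g. the bf16 output of an autocast Linear
with torch.autocast("cpu", dtype=torch.bfloat16):
    promoted = PositionalEncoding(d_model=8)(x)
    matched = PositionalEncoding(d_model=8, match_input_dtype=True)(x)

print(promoted.dtype, matched.dtype)  # torch.float32 torch.bfloat16
```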

If I remove that layer, it still crashes with the same error.

Minimal repro:

import math

import pytorch_lightning as pl
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, IterableDataset

class PoseFSQAutoEncoder(nn.Module):
    # pylint: disable=too-many-arguments
    def __init__(self,
                 pose_dims: tuple = (178, 3),
                 hidden_dim=512,
                 nhead=16,
                 dim_feedforward=2048,
                 num_layers=6):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Flatten(start_dim=2),
            nn.Linear(math.prod(pose_dims), hidden_dim, bias=False),
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
                                           dim_feedforward=dim_feedforward,
                                           batch_first=True),
                num_layers=num_layers
            )
        )

    def forward(self, batch: Tensor):
        return self.encoder(batch)

class AutoEncoderLightningWrapper(pl.LightningModule):
    def __init__(self, model: PoseFSQAutoEncoder,
                 learning_rate: float = 3e-4,
                 warmup_steps: int = 10000):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps

    def forward(self, batch):
        return self.model(batch)

    def configure_optimizers(self):
        # Optimizer taken from https://arxiv.org/pdf/2307.09288.pdf
        return torch.optim.AdamW(self.parameters(),
                                 lr=self.learning_rate,
                                 betas=(0.9, 0.95),
                                 eps=1e-5,
                                 weight_decay=0.1)

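    def step(self, x: Tensor):
        # The repro encoder returns a single tensor (the real model returns (x_hat, indices))
        x_hat = self(x)

        # fake loss, for repro; training_step must return a Tensor, not a plain int
        loss = x_hat.float().pow(2).mean()
        return loss, x_hat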

    def training_step(self, batch, *args, **kwargs):
        loss, _ = self.step(batch)
        return loss

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        loss, prediction = self.step(batch)
        return loss

class FakeDataset(IterableDataset):
    def __iter__(self):
        while True:
            yield torch.randn(size=(10, 178, 3))

auto_encoder = PoseFSQAutoEncoder()
model = AutoEncoderLightningWrapper(auto_encoder)

train_dataset = DataLoader(FakeDataset(),
                           batch_size=2,
                           num_workers=0)
validation_dataset = DataLoader(FakeDataset(),
                                batch_size=2,
                                shuffle=False,
                                num_workers=0)

precision = "bf16-mixed"
trainer = pl.Trainer(max_steps=100000,
                     val_check_interval=100_000 // 2,
                     precision=precision,
                     )

trainer.fit(model, train_dataloaders=train_dataset, val_dataloaders=validation_dataset)

awaelchli commented 1 week ago

@AmitMY The error occurs during validation, so running training_step with the model in training mode won't reveal the issue. Take a look at this PyTorch-only snippet, derived from your code, which shows that the transformer model behaves differently in eval mode:

# No lightning code involved below
import math

import torch
from torch import nn, Tensor

class PoseFSQAutoEncoder(nn.Module):
    def __init__(self, pose_dims: tuple = (178, 3), hidden_dim=512, nhead=16, dim_feedforward=2048, num_layers=6):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Flatten(start_dim=2),
            nn.Linear(math.prod(pose_dims), hidden_dim, bias=False),
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=hidden_dim, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
                ),
                num_layers=num_layers,
            ),
        )

    def forward(self, batch: Tensor):
        return self.encoder(batch)

model = PoseFSQAutoEncoder()
batch = torch.randn(size=(2, 10, 178, 3))

model.eval()  # <--- HERE: Different behavior .train() vs .eval()

with torch.no_grad():
    with torch.autocast("cpu", dtype=torch.bfloat16):
        model(batch)

As you can see, it produces the same error.
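To make the train/eval contrast explicit, here is a sketch that runs the same forward pass under CPU bf16 autocast once in train mode and once in eval mode, catching the error instead of crashing. The helpers make_encoder and forward_under_autocast and the reduced sizes are hypothetical; the settings that matter are kept (batch_first=True, an even nhead, no grad), so in eval mode the layer can dispatch to the fused torch._transformer_encoder_layer_fwd kernel seen in the traceback, which appears to use the float32 weights as-is rather than going through autocast. On torch 2.2.2 (the version in the report) the eval call is expected to raise; other PyTorch versions may behave differently.

```python
import math

import torch
from torch import nn


def make_encoder(hidden_dim: int = 64, nhead: int = 4) -> nn.Module:
    # Reduced (hypothetical) sizes; same layer types as the snippet above.
    return nn.Sequential(
        nn.Flatten(start_dim=2),
        nn.Linear(math.prod((178, 3)), hidden_dim, bias=False),
        nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
                                       dim_feedforward=2 * hidden_dim, batch_first=True),
            num_layers=2,
        ),
    )


def forward_under_autocast(model: nn.Module, batch: torch.Tensor, train: bool):
    """Return the output tensor, or the RuntimeError raised by the forward pass."""
    model.train(train)
    try:
        with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
            return model(batch)
    except RuntimeError as err:
        return err


torch.manual_seed(0)
model = make_encoder()
batch = torch.randn(2, 10, 178, 3)

train_result = forward_under_autocast(model, batch, train=True)  # regular path: ops are autocast
eval_result = forward_under_autocast(model, batch, train=False)  # may take the fused fast path

for mode, result in (("train", train_result), ("eval", eval_result)):
    if isinstance(result, RuntimeError):
        print(f"{mode}: RuntimeError: {result}")
    else:
        print(f"{mode}: ok, output {tuple(result.shape)} {result.dtype}")
```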