Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

ONNX export doesn't work with BFloat16 #19337

Closed · pfeatherstone closed 8 months ago

pfeatherstone commented 8 months ago

Bug description

Exporting a model with BFloat16 weights to ONNX breaks. Lightning normally transfers the inputs to the correct device, and I assumed it also cast them to the correct dtype, but that does not appear to be the case.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import os
import torch
import lightning.pytorch as pl

class RandomDataset(torch.utils.data.Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

class CallExport(pl.callbacks.Callback):
    def on_validation_end(self, trainer, pl_module: BoringModel):
        if trainer.is_global_zero:
            print("Running ONNX export")
            # float32 sample, but under precision="bf16-true" the module's
            # weights are bfloat16, so the export below fails with a dtype mismatch
            x = torch.randn(10, 32)
            pl_module.to_onnx('/tmp/model.onnx', input_sample=(x,))
        trainer.strategy.barrier()

def run():
    train_data  = torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data    = torch.utils.data.DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        precision="bf16-true",
        callbacks=[CallExport()],
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

if __name__ == "__main__":
    run()
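
The failure can also be isolated outside the Trainer entirely; with bfloat16 weights and a float32 input, plain PyTorch raises the same error (a minimal sketch of the underlying mismatch):

import torch

# Mimic what precision="bf16-true" does to the weights.
layer = torch.nn.Linear(32, 2).to(torch.bfloat16)

# A default float32 sample, as in the callback above.
x = torch.randn(10, 32)

# Raises: RuntimeError: mat1 and mat2 must have the same dtype,
# but got Float and BFloat16
layer(x)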

Error messages and logs

RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16

Environment

Current environment

* CUDA:
  - GPU: NVIDIA RTX A6000
  - available: True
  - version: 12.1
* Lightning:
  - lightning: 2.1.3
  - lightning-utilities: 0.10.0
  - pytorch-lightning: 2.1.2
  - recurrent-memory-transformer-pytorch: 0.5.5
  - torch: 2.1.2
  - torchaudio: 2.1.2
  - torchmetrics: 1.2.0
  - torchvision: 0.16.2
* Packages: absl-py: 2.0.0, aiofiles: 23.2.1, aiohttp: 3.9.1, aiosignal: 1.3.1, argparse: 1.4.0, async-timeout: 4.0.3, attrs: 23.1.0, cachetools: 5.3.2, certifi: 2023.11.17, charset-normalizer: 3.3.2, coloredlogs: 15.0.1, contourpy: 1.2.0, cycler: 0.12.1, datasets: 2.16.1, dill: 0.3.7, einops: 0.7.0, filelock: 3.13.1, flashlight: 0.1.1, flashlight-text: 0.0.4, flatbuffers: 23.5.26, fonttools: 4.45.1, frozenlist: 1.4.0, fsspec: 2023.10.0, future: 0.18.3, google-auth: 2.23.4, google-auth-oauthlib: 1.1.0, grpcio: 1.59.3, httptools: 0.6.1, huggingface-hub: 0.20.1, humanfriendly: 10.0, idna: 3.6, jaxtyping: 0.2.23, jinja2: 3.1.2, kiwisolver: 1.4.5, lightning: 2.1.3, lightning-utilities: 0.10.0, llvmlite: 0.41.1, mako: 1.3.0, markdown: 3.5.1, markdown-it-py: 3.0.0, markupsafe: 2.1.3, matplotlib: 3.8.2, mdurl: 0.1.2, mpmath: 1.3.0, multidict: 6.0.4, multiprocess: 0.70.15, networkx: 3.2.1, numba: 0.58.1, numpy: 1.26.2, nvidia-cublas-cu12: 12.1.3.1, nvidia-cuda-cupti-cu12: 12.1.105, nvidia-cuda-nvrtc-cu12: 12.1.105, nvidia-cuda-runtime-cu12: 12.1.105, nvidia-cudnn-cu12: 8.9.2.26, nvidia-cufft-cu12: 11.0.2.54, nvidia-curand-cu12: 10.3.2.106, nvidia-cusolver-cu12: 11.4.5.107, nvidia-cusparse-cu12: 12.1.0.106, nvidia-nccl-cu12: 2.18.1, nvidia-nvjitlink-cu12: 12.3.101, nvidia-nvtx-cu12: 12.1.105, oauthlib: 3.2.2, onnx: 1.15.0, onnxruntime: 1.16.3, onnxruntime-gpu: 1.16.3, onnxscript: 0.1.0.dev20231213, onnxsim: 0.4.35, packaging: 23.2, pandas: 2.1.4, pillow: 10.1.0, pip: 22.0.2, pipe: 2.0, protobuf: 4.23.4, pyarrow: 14.0.2, pyarrow-hotfix: 0.6, pyasn1: 0.5.1, pyasn1-modules: 0.3.0, pybombs: 2.3.5, pygments: 2.17.2, pyparsing: 3.1.1, pyqt5: 5.15.10, pyqt5-qt5: 5.15.2, pyqt5-sip: 12.13.0, pyqtgraph: 0.13.3, python-dateutil: 2.8.2, pytorch-lightning: 2.1.2, pytz: 2023.3.post1, pyyaml: 6.0.1, recurrent-memory-transformer-pytorch: 0.5.5, requests: 2.31.0, requests-oauthlib: 1.3.1, rich: 13.7.0, rsa: 4.9, ruamel.yaml: 0.18.5, ruamel.yaml.clib: 0.2.8, sanic: 0.7.0, scipy: 1.11.4, setuptools: 59.6.0, six: 1.16.0, sympy: 1.12, tensorboard: 2.15.1, tensorboard-data-server: 0.7.2, torch: 2.1.2, torchaudio: 2.1.2, torchmetrics: 1.2.0, torchvision: 0.16.2, tqdm: 4.66.1, triton: 2.1.0, typeguard: 2.13.3, typing-extensions: 4.8.0, tzdata: 2023.4, uhd: 4.6.0, ujson: 5.9.0, urllib3: 2.1.0, uvloop: 0.19.0, websockets: 12.0, werkzeug: 3.0.1, x-transformers: 1.27.9, xxhash: 3.4.1, yarl: 1.9.3
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.10.12
  - release: 6.2.0-33-generic
  - version: #33~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Sep 7 10:33:52 UTC 2

More info

No response

awaelchli commented 8 months ago

@pfeatherstone This is because the input is float32 while the model weights are bfloat16. To fix this, you need to create the input with the matching dtype:

x = torch.randn(10, 32, dtype=torch.bfloat16)
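
For example, the export callback from the repro can take the dtype from the module itself, so the sample stays correct under any precision setting (a sketch; `pl_module.dtype` is provided by LightningModule):

import torch
import lightning.pytorch as pl

class CallExport(pl.callbacks.Callback):
    def on_validation_end(self, trainer, pl_module):
        if trainer.is_global_zero:
            # Lightning moves the input sample to the module's device during
            # export, but not to its dtype, so derive the dtype explicitly
            # (bfloat16 when training with precision="bf16-true").
            x = torch.randn(10, 32, dtype=pl_module.dtype)
            pl_module.to_onnx('/tmp/model.onnx', input_sample=(x,))
        trainer.strategy.barrier()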

I don't think this can be handled automatically by Lightning.