Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Modules with `nn.Parameter` not Converted by Lightning Mixed Precision #19699

Open nrocketmann opened 3 months ago

nrocketmann commented 3 months ago

Bug description

I have an `nn.Module` (call it `Mod`) which adds its input `x` to an internal `nn.Parameter`. I'm using `Mod` as part of a `pl.LightningModule` that I'm training with 16-mixed precision. However, the output of calling `Mod` is a tensor with dtype `torch.float32`. When I use other layer types, they output `torch.float16` tensors as expected. This failure is often silent (as in the example provided below), but can cause issues if a model contains a component (e.g. flash attention) that requires fp16. Furthermore, after loading a model trained this way at inference time and calling `.half()` on it, the output is NaN or otherwise nonsensical, despite being perfectly fine if I load the model in fp32.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

This is a small, reproducible example with `lightning==2.0.2`. Note how the output of `Mod` has dtype `torch.float32` while the output of a linear layer has dtype `torch.float16`. The example runs distributed on 8 GPUs, but the issue is the same on a single GPU.

```python
from lightning import pytorch as pl
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader


class Mod(nn.Module):
    """Adds a frozen nn.Parameter to its input."""
    def __init__(self):
        super().__init__()
        derp = torch.randn((1, 32))
        self.p = nn.Parameter(derp, requires_grad=False)

    def forward(self, x):
        return x + self.p


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(32, 32)
        self.m = Mod()
        self.l = nn.MSELoss()

    def forward(self, x):
        print('x', x.dtype)
        y = self.lin(x)
        print('y', y.dtype)   # torch.float16 under 16-mixed, as expected
        z = self.m(y)
        print('z', z.dtype)   # torch.float32, unexpectedly

        print('p', self.m.p.dtype)
        print('lin', self.lin.weight.dtype)
        return z

    def training_step(self, batch, batch_idx):
        x, y = batch
        z = self(x)
        loss = self.l(z, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


xdata = torch.randn((1000, 32))
ydata = xdata + torch.randn_like(xdata) * .1
dataset = TensorDataset(xdata, ydata)
dataloader = DataLoader(dataset, batch_size=8, num_workers=4, pin_memory=True)
model = Model()
trainer = pl.Trainer(
    strategy='ddp',
    accelerator='gpu',
    devices=list(range(8)),
    precision='16-mixed'
)

trainer.fit(model=model, train_dataloaders=dataloader)
```

Error messages and logs

Example output:

```
Epoch 3:  78%|████████████████▏     | 98/125 [00:00<00:00, 138.94it/s, v_num=5]
x torch.float32
x torch.float32
y torch.float16
z torch.float32
p torch.float32
lin torch.float32
```

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer, LightningModule
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.2
#- PyTorch Version (e.g., 2.0): 2.1.0
#- Python version (e.g., 3.9): 3.10.12
#- OS: Ubuntu 20.04.6 LTS (Focal Fossa)
#- CUDA/cuDNN version: 11.8
#- GPU models and configuration: 8xA100
#- How you installed Lightning (`conda`, `pip`, source): pip
```

More info

Thank you for your help!

awaelchli commented 3 months ago

Hi @nrocketmann

This is not how mixed precision works. From what I can tell, what you describe is the expected behavior, so I'm removing the bug label.

Here is an explanation of what mixed precision (`precision='16-mixed'`) does: https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#what-is-mixed-precision
Lightning does the same as what PyTorch refers to as AMP: https://pytorch.org/docs/stable/amp.html#module-torch.amp

In short, the weights are kept in float32 precision, while supported operations are cast to 16-bit precision when appropriate. This uses slightly more memory but can often lead to a speed-up. That trade-off is what mixed precision gets you.
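
Here is a minimal sketch of that behavior in plain PyTorch autocast (the mechanism Lightning uses under the hood). The tensor names are illustrative stand-ins for your `lin.weight`, `p`, and input, and it assumes a CUDA device:

```python
import torch

# float32 tensors standing in for the linear weight, the extra parameter, and the input
w = torch.randn(32, 32, device="cuda")
p = torch.randn(1, 32, device="cuda")
x = torch.randn(8, 32, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = x @ w  # matmul is on autocast's float16 op list -> float16 output
    z = y + p  # elementwise add is not on the list; float16 + float32 promotes to float32
print(y.dtype)  # torch.float16
print(z.dtype)  # torch.float32
```

This mirrors exactly why `y` is float16 but `z` is float32 in your example: autocast only casts the inputs of listed ops (matmul, conv, linear, ...), and your add goes through ordinary type promotion instead.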

If you want everything in 16-bit, including the weights, you can set `precision="16-true"` or `precision="bf16-true"`: https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#true-half-precision
Be aware, though, that fully 16-bit training can be numerically unstable.
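
For example (a sketch reusing the `Model` and `dataloader` from your snippet, assuming a Lightning version that supports true half precision per the docs above):

```python
# true half precision: the weights themselves are cast to float16
trainer = pl.Trainer(accelerator="gpu", devices=1, precision="16-true")
trainer.fit(model=model, train_dataloaders=dataloader)
# here both self.m.p and self.lin.weight would print as torch.float16
```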

> This failure is often silent (as in the example provided below), but can cause issues if a model contains a component (e.g. flash attention) that requires fp16.

If a layer in your model requires a specific dtype, you need to enforce that yourself (by up- or downcasting its inputs).
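
For your `Mod` above, one way to do that (a sketch, not the only option) is to cast the parameter to the input's dtype inside `forward`, so the output dtype follows the autocast region:

```python
class Mod(nn.Module):
    def __init__(self):
        super().__init__()
        self.p = nn.Parameter(torch.randn((1, 32)), requires_grad=False)

    def forward(self, x):
        # downcast the float32 parameter to match the (possibly float16) input
        return x + self.p.to(x.dtype)
```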

nrocketmann commented 3 months ago

Hi @awaelchli, thanks for getting back to me so quickly! Glad to hear this is expected behavior. I've looked at the docs you mentioned, but a few parts still don't make sense to me:

Thanks!