AmitMY opened 3 weeks ago
@AmitMY The Trainer applies the PyTorch autocast context manager over the forward pass and converts the inputs. Take a look at the error traceback, in particular the line
output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
and then from there work out which tensors (output, weights of TransformerEncoder) mismatch the dtype. It's possible that the input tensor here is the output of the previous layer (e.g. PositionalEncoding) and the dtype mismatch needs to be fixed there.
If there is reason to believe something is not done right in Lightning, please provide a reproducible example. Thanks!
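For context, under "bf16-mixed" each step runs inside autocast while the parameters stay float32, roughly like this (simplified sketch; the actual logic lives in Lightning's precision plugin):

import torch

# simplified illustration, not the exact plugin code:
# ops that support it run in bfloat16, parameters remain float32
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model.training_step(batch)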
The reason I believe it is a problem with pytorch-lightning is that using plain torch autocasting works fine:
def test_training_step_bfloat16_expected_loss_finite(self):
    batch = MaskedTensor(torch.full((4, 3, *self.pose_dim), fill_value=2, dtype=torch.float))
    model = self.model_setup()
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        loss = model.training_step(batch)
    self.assertNotEqual(0, float(loss))
    self.assertTrue(torch.isfinite(loss))
As for the input to the transformer: both under plain torch autocast and under Lightning, I see:
dtype in PositionalEncoding torch.bfloat16
dtype out PositionalEncoding torch.float32
If I remove that layer, it still crashes with the same error.
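For reference, the dtype trace above comes from per-module logging; a minimal sketch of how such a trace can be produced with forward hooks (illustrative, not the exact code in my model):

def log_dtypes(module, inputs, output):
    # only handle plain-tensor inputs/outputs; some modules return tuples
    if inputs and isinstance(inputs[0], torch.Tensor) and isinstance(output, torch.Tensor):
        name = module.__class__.__name__
        print(f"dtype in {name} {inputs[0].dtype} dtype out {name} {output.dtype}")

for submodule in model.modules():
    submodule.register_forward_hook(log_dtypes)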
Minimal repro:
import math

import pytorch_lightning as pl
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, IterableDataset


class PoseFSQAutoEncoder(nn.Module):
    # pylint: disable=too-many-arguments
    def __init__(self,
                 pose_dims: tuple = (178, 3),
                 hidden_dim=512,
                 nhead=16,
                 dim_feedforward=2048,
                 num_layers=6):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(start_dim=2),
            nn.Linear(math.prod(pose_dims), hidden_dim, bias=False),
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
                                           dim_feedforward=dim_feedforward,
                                           batch_first=True),
                num_layers=num_layers
            )
        )

    def forward(self, batch: Tensor):
        return self.encoder(batch)


class AutoEncoderLightningWrapper(pl.LightningModule):
    def __init__(self, model: PoseFSQAutoEncoder,
                 learning_rate: float = 3e-4,
                 warmup_steps: int = 10000):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps

    def forward(self, batch):
        return self.model(batch)

    def configure_optimizers(self):
        # Optimizer taken from https://arxiv.org/pdf/2307.09288.pdf
        return torch.optim.AdamW(self.parameters(),
                                 lr=self.learning_rate,
                                 betas=(0.9, 0.95),
                                 eps=1e-5,
                                 weight_decay=0.1)

    def step(self, x: Tensor):
        x_hat = self(x)
        # fake loss for the repro: a zero tensor that still carries a grad_fn
        return x_hat.sum() * 0, x_hat

    def training_step(self, batch, *args, **kwargs):
        loss, _ = self.step(batch)
        return loss

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        loss, prediction = self.step(batch)
        return loss


class FakeDataset(IterableDataset):
    def __iter__(self):
        while True:
            yield torch.randn(size=(10, 178, 3))


auto_encoder = PoseFSQAutoEncoder()
model = AutoEncoderLightningWrapper(auto_encoder)
train_dataset = DataLoader(FakeDataset(),
                           batch_size=2,
                           num_workers=0)
validation_dataset = DataLoader(FakeDataset(),
                                batch_size=2,
                                shuffle=False,
                                num_workers=0)

precision = "bf16-mixed"
trainer = pl.Trainer(max_steps=100000,
                     val_check_interval=100_000 // 2,
                     precision=precision)
trainer.fit(model, train_dataloaders=train_dataset, val_dataloaders=validation_dataset)
@AmitMY The error occurs during validation, so running the training step under training conditions won't reveal the issue. Take a look at this PyTorch-only code snippet derived from your code, which shows that the transformer model behaves differently in eval mode:
# No lightning code involved below
import math

import torch
from torch import nn, Tensor


class PoseFSQAutoEncoder(nn.Module):
    def __init__(self, pose_dims: tuple = (178, 3), hidden_dim=512, nhead=16, dim_feedforward=2048, num_layers=6):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(start_dim=2),
            nn.Linear(math.prod(pose_dims), hidden_dim, bias=False),
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=hidden_dim, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
                ),
                num_layers=num_layers,
            ),
        )

    def forward(self, batch: Tensor):
        return self.encoder(batch)


model = PoseFSQAutoEncoder()
batch = torch.randn(size=(2, 10, 178, 3))
model.eval()  # <--- HERE: different behavior in .train() vs .eval()
with torch.no_grad():
    with torch.autocast("cpu", dtype=torch.bfloat16):
        model(batch)
As you can see, it produces the same error.
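A quick way to confirm the mode dependence (my own check, assuming the eval-only fused fast path of nn.TransformerEncoder is what trips the dtype check):

# Same call as above, but with the module left in training mode:
# nn.TransformerEncoder skips its fused fast path and no error is raised.
model.train()
with torch.no_grad():
    with torch.autocast("cpu", dtype=torch.bfloat16):
        out = model(batch)
print(out.dtype)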
Bug description
bf16 precision in Trainer yields an error.

What version are you seeing the problem on?

v2.3
How to reproduce the bug
My model includes the encoder shown in the minimal repro above (Flatten → Linear → TransformerEncoder).
Then, run the Trainer with precision="bf16-mixed".
(Note: "bf16-true" works, but yields a very bad learning curve.)
Error messages and logs
Environment
Current environment
* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning-utilities: 0.11.2
  - pytorch-lightning: 2.3.0
  - torch: 2.2.2
  - torchmetrics: 1.4.0.post0
  - vector-quantize-pytorch: 1.14.24
* Packages:
  - aiohttp: 3.9.5
  - aiosignal: 1.3.1
  - astroid: 3.2.2
  - attrs: 23.2.0
  - certifi: 2024.6.2
  - charset-normalizer: 3.3.2
  - click: 8.1.7
  - datasets: 2.20.0
  - decorator: 4.4.2
  - dill: 0.3.8
  - docker-pycreds: 0.4.0
  - einops: 0.8.0
  - einx: 0.3.0
  - filelock: 3.15.1
  - frozendict: 2.4.4
  - frozenlist: 1.4.1
  - fsspec: 2024.5.0
  - gitdb: 4.0.11
  - gitpython: 3.1.43
  - huggingface-hub: 0.23.3
  - idna: 3.7
  - imageio: 2.34.1
  - imageio-ffmpeg: 0.5.1
  - iniconfig: 2.0.0
  - isort: 5.13.2
  - jinja2: 3.1.4
  - lightning-utilities: 0.11.2
  - markupsafe: 2.1.5
  - mccabe: 0.7.0
  - moviepy: 1.0.3
  - mpmath: 1.3.0
  - multidict: 6.0.5
  - multiprocess: 0.70.16
  - networkx: 3.3
  - numpy: 1.26.4
  - opencv-python: 4.10.0.82
  - packaging: 24.1
  - pandas: 2.2.2
  - pillow: 10.3.0
  - pip: 24.0
  - platformdirs: 4.2.2
  - pluggy: 1.5.0
  - pose-format: 0.4.1
  - proglog: 0.1.10
  - protobuf: 5.27.1
  - psutil: 5.9.8
  - pyarrow: 16.1.0
  - pyarrow-hotfix: 0.6
  - pylint: 3.2.3
  - pytest: 8.2.2
  - python-dateutil: 2.9.0.post0
  - pytorch-lightning: 2.3.0
  - pytz: 2024.1
  - pyyaml: 6.0.1
  - requests: 2.32.3
  - scipy: 1.13.1
  - sentry-sdk: 2.5.1
  - setproctitle: 1.3.3
  - setuptools: 69.5.1
  - sign-vq: 0.0.1
  - six: 1.16.0
  - smmap: 5.0.1
  - sympy: 1.12.1
  - tomlkit: 0.12.5
  - torch: 2.2.2
  - torchmetrics: 1.4.0.post0
  - tqdm: 4.66.4
  - typing-extensions: 4.12.2
  - tzdata: 2024.1
  - urllib3: 2.2.1
  - vector-quantize-pytorch: 1.14.24
  - wandb: 0.17.1
  - wheel: 0.43.0
  - xxhash: 3.4.1
  - yarl: 1.9.4
* System:
  - OS: Darwin
  - architecture: 64bit
  - processor: i386
  - python: 3.11.9
  - release: 23.5.0
  - version: Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000

More info
I tried to follow https://github.com/Lightning-AI/pytorch-lightning/issues/15006 and feed the batch directly as bf16; that does not change the error.
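That is, casting the batch before it reaches the step hooks, along these lines (hypothetical sketch; the exact override may differ):

class AutoEncoderLightningWrapper(pl.LightningModule):
    # ... rest of the module as in the minimal repro above ...

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # hypothetical: hand the step an already-bf16 batch
        # (assumes the batch is a plain tensor that supports .to())
        return batch.to(torch.bfloat16)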