nrocketmann opened this issue 3 months ago
Hi @nrocketmann
This is not how mixed precision works. What you describe is the expected behavior from what I can judge. I'm removing the bug label.
Here is an explanation of what mixed precision (`precision="16-mixed"`) does:
https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#what-is-mixed-precision
Lightning does the same as what is referred to as AMP in PyTorch:
https://pytorch.org/docs/stable/amp.html#module-torch.amp
In short, the weights are kept in float32 precision, while supported operations are cast to 16-bit precision when appropriate. This uses slightly more memory but can often lead to a speedup. This is what mixed precision gets you.
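To make this concrete, here is a minimal sketch (mine, not from this thread) of the `torch.autocast` behavior that `16-mixed` relies on; it also previews why the add in the bug report below stays float32:

```python
import torch

# Requires a CUDA device. Weights stay float32; under autocast, only
# eligible ops (e.g. matmul) run in float16, while a plain add with a
# parameter keeps the float32 input dtype.
weight = torch.nn.Parameter(torch.randn(4, 4, device="cuda"))
bias = torch.nn.Parameter(torch.randn(4, device="cuda"))
x = torch.randn(2, 4, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y_matmul = x @ weight  # matmul is autocast-eligible -> torch.float16
    y_add = x + bias       # add is not on the cast list -> torch.float32

print(y_matmul.dtype, y_add.dtype)
```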
If you want everything in 16-bit, including the weights, you can set `precision="16-true"` or `precision="bf16-true"`:
https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html#true-half-precision
You will have to be aware of numerical instability during training.
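In `Trainer` terms, the two modes look like this (a minimal sketch; the precision values come from the linked docs):

```python
from lightning.pytorch import Trainer

trainer_mixed = Trainer(precision="16-mixed")  # fp32 weights, fp16 autocast ops
trainer_true = Trainer(precision="bf16-true")  # weights and ops in bfloat16
```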
If a layer in your model requires a specific dtype, you need to enforce that yourself (by up- or downcasting).
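For example, a minimal sketch of such manual casting, assuming a hypothetical parameter-free submodule that only accepts fp16 inputs (the wrapper name is my own, not a Lightning API):

```python
import torch
import torch.nn as nn

class ForceFloat16(nn.Module):
    """Downcast the input for a submodule that requires fp16
    (e.g. a flash-attention kernel), then upcast the result back."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x.to(torch.float16)).to(x.dtype)
```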
Hi @awaelchli, thanks for getting back to me so quickly! Glad to hear this is expected behavior. I've looked at the docs you mentioned, and a few parts still don't make sense to me:

1. How am I supposed to run a model trained with `16-mixed` precision in 16-bit at inference time? Calling `.half()` does not achieve this because it converts everything, including all the weights, to fp16. Or is it rather the case that fp16 mixed precision is only used for training models that should be run at inference time in full fp32 precision? I've found that when training in mixed precision and then calling `.half()` at inference time, my results are nonsensical.
2. Is the `Mod` layer outputting fp32 because it falls under the category of "not supported" operations?
3. When I tried adding `.to(torch.float16)` statements in my model, I would get back "Cannot unscale FP16 Gradients" errors. Is there another way to do this kind of casting?

Thanks!
Bug description

I have an `nn.Module` (call it `Mod`) which adds its input `x` to an internal `nn.Parameter`. I'm using `Mod` as part of a `pl.LightningModule` which I'm training in `16-mixed` precision. However, the output of calling `Mod` is a tensor with dtype `torch.float32`. When I use other layer types, they output `torch.float16` tensors as expected. This failure is often silent (as in the example provided below), but can cause issues if a model contains a component (e.g. flash attention) that requires fp16. Furthermore, after loading a model trained this way at inference time and calling `.half()` on it, the output is `NaN` or otherwise nonsensical, despite being perfectly fine if I load the model in fp32.

What version are you seeing the problem on?

v2.0
How to reproduce the bug

This is a small, reproducible example with `lightning==2.0.2`. Note how the output of `Mod` has dtype `torch.float32` while the output of a linear layer has dtype `torch.float16`. The example runs distributed on 8 GPUs, but the issue is the same on a single GPU.
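A hypothetical reconstruction of such an example, based purely on the description above (all names other than `Mod` are my own):

```python
import torch
import torch.nn as nn
import lightning.pytorch as pl

class Mod(nn.Module):
    """Adds its input to an internal nn.Parameter, as in the description."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.param = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return x + self.param  # under autocast, this add stays float32

class LitModel(pl.LightningModule):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.mod = Mod(dim)
        self.linear = nn.Linear(dim, dim)

    def training_step(self, batch, batch_idx):
        print("Mod output dtype:   ", self.mod(batch).dtype)     # torch.float32
        print("Linear output dtype:", self.linear(batch).dtype)  # torch.float16
        return self.linear(self.mod(batch)).pow(2).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

if __name__ == "__main__":
    loader = torch.utils.data.DataLoader(torch.randn(32, 8), batch_size=4)
    trainer = pl.Trainer(accelerator="gpu", devices=1,
                         precision="16-mixed", max_steps=2)
    trainer.fit(LitModel(), loader)
```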
Error messages and logs

Example output: the `Mod` output has dtype `torch.float32`, while the linear layer output has dtype `torch.float16`.
Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer, LightningModule
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0.2
#- PyTorch Version (e.g., 2.0): 2.1.0
#- Python version (e.g., 3.9): 3.10.12
#- OS: Ubuntu 20.04.6 LTS (Focal Fossa)
#- CUDA/cuDNN version: 11.8
#- GPU models and configuration: 8xA100
#- How you installed Lightning (`conda`, `pip`, source): pip
```

More info
Thank you for your help!