Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

[RFC] "auto" precision support #12895

Open · justusschock opened this issue 2 years ago

justusschock commented 2 years ago

🚀 Feature

Introduce support for precision="auto", which enables AMP whenever it is available and falls back to fp32 otherwise.

Motivation

Since accelerator="auto" was introduced, it has become hard to decide which precision to use. Personally, I use AMP whenever I'm running on a GPU that supports it; otherwise I use fp32. Because precision has to be set at the same time as the accelerator, with accelerator="auto" we cannot easily tell at that point which accelerator we will be running on without a fair amount of boilerplate. (We could, but that would essentially duplicate the accelerator-detection logic Lightning already uses internally.)

Pitch

Add precision="auto", which switches between AMP and fp32 depending on the accelerator and on whether the accelerator at hand supports AMP.
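A minimal sketch of the pitched rule, assuming it runs after accelerator="auto" has already been resolved (so it reuses that result instead of duplicating detection) and assuming the Lightning 1.x convention where precision=16 means AMP. The function name and the amp_supported flag are hypothetical, not Lightning API.

```python
from typing import Union

Precision = Union[int, str]


def resolve_precision(precision: Precision, accelerator: str, amp_supported: bool) -> Precision:
    """Hypothetical resolver for precision="auto".

    `accelerator` is the already-resolved accelerator name (e.g. "gpu" or
    "cpu"), so no device detection is repeated here. Assumes the Lightning
    1.x convention where 16 selects AMP and 32 selects full fp32.
    """
    if precision != "auto":
        # An explicit choice (16, 32, "bf16", ...) is passed through unchanged.
        return precision
    if accelerator == "gpu" and amp_supported:
        # GPU that supports AMP: use mixed precision.
        return 16
    # Anything else: fall back to fp32.
    return 32


print(resolve_precision("auto", "gpu", amp_supported=True))   # 16
print(resolve_precision("auto", "cpu", amp_supported=False))  # 32
```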

Alternatives

Manually write out all the boilerplate.
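For comparison, the boilerplate users write today might look roughly like this. It is only a sketch: manual_precision is a hypothetical helper, and treating compute capability 7.0 or higher as "supports AMP" is an assumption chosen for illustration. Note that it re-implements part of the device detection that accelerator="auto" already performs, which is exactly the duplication this issue wants to avoid.

```python
try:
    import torch
except ImportError:  # torch not installed: treat as CPU-only
    torch = None


def manual_precision() -> int:
    """Hypothetical helper: 16 (AMP) on a capable CUDA GPU, else 32 (fp32)."""
    if torch is None or not torch.cuda.is_available():
        return 32
    # Assumption for illustration: compute capability >= 7.0 (Volta or newer,
    # the first generation with tensor cores) counts as "supports AMP".
    major, _minor = torch.cuda.get_device_capability()
    return 16 if major >= 7 else 32


precision = manual_precision()
# ...which then gets passed alongside the auto-detected accelerator, e.g.
# trainer = Trainer(accelerator="auto", precision=precision)
print(precision)
```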



cc @borda @tchaton @justusschock @awaelchli @carmocca @akihironitta @rohitgr7 @kaushikb11 @ananthsub

carmocca commented 2 years ago

I see the usefulness of this in the general base case. However, it opens a can of worms in terms of unconventional accelerators, strategies, and your actual hardware (the precision choice depends on the hardware too).

There will be combinations of these where we won't be able to make an informed decision and will likely need to default to regular precision. In those cases, setting precision="auto" might be misleading.
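One way to keep "auto" from being misleading in those cases would be to pick AMP only for an explicit allowlist of known-good combinations and to warn whenever it falls back. The allowlist contents and the function name below are hypothetical, and the strategy name is only illustrative.

```python
import warnings
from typing import Optional

# Hypothetical allowlist of (accelerator, strategy) pairs for which "auto"
# is trusted to choose AMP. None means "no explicit strategy".
AMP_SAFE_COMBINATIONS = {
    ("gpu", None),
    ("gpu", "ddp"),
}


def resolve_auto_precision(accelerator: str, strategy: Optional[str] = None) -> int:
    """Return 16 (AMP) for allowlisted combinations, else warn and return 32."""
    if (accelerator, strategy) in AMP_SAFE_COMBINATIONS:
        return 16
    # Unknown or unconventional setup: don't guess, tell the user what happened.
    warnings.warn(
        f'precision="auto" could not decide on AMP for accelerator={accelerator!r}, '
        f"strategy={strategy!r}; falling back to 32-bit precision.",
        UserWarning,
    )
    return 32
```

Keeping the allowlist explicit means new accelerators or strategies stay on fp32 until someone has verified AMP works there, and the warning makes that fallback visible. A hardware check (whether the GPU on hand actually supports AMP) would still be needed on top of this.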

justusschock commented 2 years ago

@carmocca I think we would still default to fp32, and users can opt into "auto". If AMP is not supported by the hardware, we can simply error out, I guess.

However, I really think that we need a way to say "use amp if on gpu and else use fp32" :)
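The behaviour described in this comment (fp32 stays the default, "auto" is opt-in, and an explicit AMP request on unsupported hardware errors out) could be sketched as follows. The names are hypothetical, and ValueError stands in for whatever exception Lightning would actually raise.

```python
from typing import Union

# fp32 remains the default; "auto" is something users opt into.
DEFAULT_PRECISION: Union[int, str] = 32


def check_precision(precision: Union[int, str], accelerator: str, amp_supported: bool) -> None:
    """Hypothetical validation: error out if AMP is explicitly requested but unavailable."""
    if precision == 16 and not (accelerator == "gpu" and amp_supported):
        raise ValueError(
            f"precision=16 (AMP) was requested, but accelerator={accelerator!r} "
            "on this machine does not support it."
        )


check_precision(DEFAULT_PRECISION, "cpu", amp_supported=False)  # fp32 is always fine
check_precision(16, "gpu", amp_supported=True)                  # AMP on a capable GPU
try:
    check_precision(16, "cpu", amp_supported=False)
except ValueError as err:
    print(err)
```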