pytorch / torchdistx

Torch Distributed Experimental
BSD 3-Clause "New" or "Revised" License

[AnyPrecision optimizer] consider FP32 defaults, possibly automated via BF16 support check #59

Open lessw2020 opened 2 years ago

lessw2020 commented 2 years ago

Enhancement (credit to @rohan-varma): "This can be done in a follow-up PR, but let's consider eventually not defaulting things to torch.bfloat16. It might be good to make this optimizer usable out of the box with the defaults on all HW architectures, but only A100 supports bfloat16 well at the moment.

But the downside here would be that the default optimizer won't be too interesting; it'd just be AdamW."

One way to accomplish this would be a simple native BF16 support check that reverts any BF16 defaults to FP32 (and turns off Kahan summation as well, since it adds no benefit in FP32). The remaining dilemma is whether to warn the user about this change: the upside is that they know they are not getting the BF16 benefits; the downside is that they may already be aware and won't enjoy seeing the same one-line warning repeated across 128 GPUs.
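A minimal sketch of what that check could look like. `resolve_dtype_defaults` is a hypothetical helper; `torch.cuda.is_bf16_supported()` is the existing PyTorch query, and the rank-0 guard is one possible way to avoid the warning-times-128-GPUs problem.

```python
# Hypothetical sketch: fall back from BF16 defaults to FP32 when native BF16
# is unavailable, warning only once (on rank 0) instead of once per GPU.
import warnings

import torch
import torch.distributed as dist


def resolve_dtype_defaults(requested_dtype: torch.dtype = torch.bfloat16) -> torch.dtype:
    """Return the requested dtype if natively supported, else fall back to FP32."""
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    if requested_dtype is torch.bfloat16 and not bf16_ok:
        # Warn only on rank 0 so a 128-GPU job does not emit 128 identical lines.
        if not dist.is_initialized() or dist.get_rank() == 0:
            warnings.warn(
                "Native BF16 is not supported on this device; "
                "falling back to FP32 optimizer states (Kahan summation disabled)."
            )
        return torch.float32
    return requested_dtype
```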

lessw2020 commented 2 years ago

from review discussion: "Actually, I think having this concept of 'smart defaults', where the optimizer attempts BF16 but rolls back to FP32 when BF16 is not supported, is a nice user experience. The same could apply to Kahan summation: if BF16 is not supported, turn it back off and move both momentum and variance to FP32 automatically (since Kahan adds no value for FP32). This would be nice because users would inherently get the optimal setup on their hardware.

That also gives it a bit more value even in the scenario where it becomes AdamW, because it automatically switches back to BF16 on future hardware where BF16 is supported."
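A rough sketch of the 'smart defaults' idea as a construction helper. The import path and the keyword names (`momentum_dtype`, `variance_dtype`, `use_kahan_summation`) are assumed from the current AnyPrecisionAdamW signature and may differ; `make_anyprecision_adamw` is a hypothetical wrapper, not part of the library.

```python
# Sketch of "smart defaults": pick BF16 states plus Kahan summation where BF16
# is natively supported, otherwise roll everything back to plain FP32 AdamW.
import torch
from torchdistx.optimizers import AnyPrecisionAdamW  # assumed import path


def make_anyprecision_adamw(params, lr=1e-3):
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    state_dtype = torch.bfloat16 if bf16_ok else torch.float32
    return AnyPrecisionAdamW(
        params,
        lr=lr,
        momentum_dtype=state_dtype,
        variance_dtype=state_dtype,
        # Kahan summation only helps when states are kept in reduced precision,
        # so it is turned off together with the BF16 -> FP32 fallback.
        use_kahan_summation=bf16_ok,
    )
```

On hardware without native BF16 this resolves to FP32 states with Kahan off (effectively AdamW), while on A100-class or future hardware the same call picks up BF16 states and Kahan summation automatically.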