Add Automatic Mixed Precision (AMP) support to CorrDiff training:
Add an fp_mode config parameter that can be one of {'fp32', 'fp16', 'amp'}. This supersedes the fp16 config parameter, though the latter is still supported if fp_mode is not present. fp_mode == 'fp32' and fp_mode == 'fp16' correspond to fp16 == False and fp16 == True, respectively, while fp_mode == 'amp' activates AMP.
Execute training forward pass in a torch.autocast environment that is enabled if fp_mode == 'amp'.
Disable certain datatype checks in U-Nets if AMP is enabled (these would otherwise cause an error).
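The precedence between fp_mode and the legacy fp16 parameter can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the helper resolve_fp_mode, the PrecisionSettings container, the use of a plain dict for the config, and the fallback to 'fp32' when neither parameter is set are all assumptions made for the example.

```python
from dataclasses import dataclass

VALID_FP_MODES = ("fp32", "fp16", "amp")


@dataclass(frozen=True)
class PrecisionSettings:
    """Hypothetical container for the resolved precision options."""

    fp_mode: str
    use_fp16: bool  # equivalent to the legacy fp16 == True
    use_amp: bool   # intended as the `enabled` argument of torch.autocast


def resolve_fp_mode(cfg: dict) -> PrecisionSettings:
    """Resolve the precision mode, preferring fp_mode over the legacy fp16 flag."""
    fp_mode = cfg.get("fp_mode")
    if fp_mode is None:
        # Legacy path: fp_mode is absent, so fall back to the boolean fp16 flag.
        # Defaulting to fp32 when fp16 is also absent is an assumption.
        fp_mode = "fp16" if cfg.get("fp16", False) else "fp32"
    if fp_mode not in VALID_FP_MODES:
        raise ValueError(
            f"fp_mode must be one of {VALID_FP_MODES}, got {fp_mode!r}"
        )
    return PrecisionSettings(
        fp_mode=fp_mode,
        use_fp16=(fp_mode == "fp16"),
        use_amp=(fp_mode == "amp"),
    )


settings = resolve_fp_mode({"fp_mode": "amp"})
# In a training loop, the forward pass would then be wrapped along these lines
# (model, loss_fn and batch are placeholders):
#
#   with torch.autocast(device_type="cuda", enabled=settings.use_amp):
#       loss = loss_fn(model, batch)
```

Passing the flag through the enabled argument keeps a single code path for all three modes, since torch.autocast is a no-op when enabled is False.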
Modulus Pull Request
Description
Checklist
Dependencies