Azure / MS-AMP

Microsoft Automatic Mixed Precision Library
https://azure.github.io/MS-AMP/
MIT License

Questions: Clarifying the use of FP8 for Training #99

Open jon-chuang opened 10 months ago

jon-chuang commented 10 months ago

@tocean @wkcn

In line with the investigation in https://github.com/NVIDIA/TransformerEngine/issues/424, it would be great to get insights from the team at Microsoft on using FP8 in aspects of training besides matmul.

Questions

1. Performance

The repo only mentions training accuracy and memory savings. However, the kernels may not be very optimized, and the majority are implemented in Torch. I guess that performance is still unexplored.

2. Weight Update

3. More Accurate Scaling Factors

Is there a way to maintain more accurate amax by estimating:

4. Adaptive Precision

Has using lower precision (FP8) at high learning rates (earlier epochs) and higher precision (e.g. FP32, FP16) at lower learning rates (later epochs) been explored?

Appendix

Update Rule for sqrt_v_fp8

# 448 = 2^8 * (1 + 3/4) = 57344 / 2^7: use more of the fp8e5m2 dynamic range, margin = 7
scaling_factor = amax_sqrt_v_prev / 448.0

# Dequantize sqrt(v) to FP32 and square it to recover v
v_fp32 = (sqrt_v_fp8.to(torch.float32) * scaling_factor) ** 2
# Standard Adam second-moment update
v_new = beta_2 * v_fp32 + (1 - beta_2) * grad_sq
# Requantize sqrt(v) to fp8e5m2 with the same scaling factor
sqrt_v_fp8 = (torch.sqrt(v_new) / scaling_factor).to(torch.float8_e5m2)

# end of loop: amax of sqrt(v) for the next iteration's scaling factor
amax_sqrt_v_new = torch.sqrt(v_new.max())

Notes:

  1. If amax_sqrt_v_fp8 = 448.0, then the scaling factor is 1. This is captured in margin bits: https://github.com/Azure/MS-AMP/blob/aed29d6533e8ff86686fdb6fafa6d0b720e9e5f6/msamp/common/tensor/meta.py#L39
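For reference, a minimal sketch of how a margin-based scaling factor can be derived, using the divisor convention from the appendix above (stored = value / scaling_factor); the exact formula in meta.py may differ:

import math

def compute_scaling_factor(amax: float, fp8_max: float = 57344.0, margin: int = 7) -> float:
    # Power-of-two scaling factor with `margin` bits of headroom below the format max
    # (57344 / 2**7 = 448 for fp8e5m2, matching the 448 constant in the appendix).
    if amax <= 0.0 or not math.isfinite(amax):
        return 1.0
    exp = math.floor(math.log2(fp8_max / amax)) - margin
    return 2.0 ** (-exp)

# amax = 448.0 gives exp = 7 - 7 = 0, i.e. a scaling factor of 1, as in note 1.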
wkcn commented 9 months ago

Hi @jon-chuang, sorry for the late reply. Thanks for your attention to our work!

1. Performance

The repo only mentions training accuracy and memory savings. However, the kernels may not be very optimized, and the majority are implemented in Torch. I guess that performance is still unexplored.

Yes. Our first step is to apply the FP8 format as much as possible to reduce the memory footprint while maintaining accuracy; the second step is to optimize performance in MS-AMP. MS-AMP can be combined with TransformerEngine to invoke the optimized operators in TE. (Related PR: https://github.com/Azure/MS-AMP/pull/98)
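For context, a minimal usage sketch along the lines of the README; how the TE-backed operators are enabled on top of this is governed by the PR above and may differ from this sketch:

import torch
import msamp

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Wrap the model and optimizer so that weights, gradients and optimizer states
# are stored in FP8/low-bit formats; "O2" also covers the optimizer states.
model, optimizer = msamp.initialize(model, optimizer, opt_level="O2")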

2. Weight Update

a) It is applied after the entire backward pass is complete. The FP8 weights are updated in the optimizer. (https://github.com/Azure/MS-AMP/blob/main/msamp/optim/adamw.py#L193).

b) Good idea. I had tried using an additional CUDA stream for the weight update, but it did not achieve the desired acceleration, probably because my implementation was not optimal : ) However, I still believe that scheduling weight updates concurrently is effective, since the weight update does not affect the calculation of backpropagation.
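A minimal sketch of this kind of overlap, using a separate CUDA stream for the optimizer step; this is illustrative only and not the implementation that was tried:

import torch

update_stream = torch.cuda.Stream()

def async_weight_update(optimizer):
    # Run the optimizer step on a side stream so it can overlap with work
    # queued afterwards on the default stream.
    update_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(update_stream):
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

def wait_for_weight_update():
    # The next forward pass must not read the weights until the update is done.
    torch.cuda.current_stream().wait_stream(update_stream)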

It is possible to update multiple FP8 weights in a single CUDA kernel, but note that an FP8 tensor with a scaling factor must be treated as a whole: the maximum value of the entire tensor should be computed before quantizing a high-precision tensor to an FP8 tensor.
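A rough sketch of that constraint with a hypothetical helper (not MS-AMP's kernel), using the same divisor convention for the scaling factor as the appendix:

import torch

def quantize_to_fp8_e4m3(t: torch.Tensor):
    # The scaling factor is derived from the amax of the *whole* tensor, so it must be
    # computed (or taken from amax history) before any element is cast to FP8.
    amax = t.abs().max()
    scale = torch.where(amax > 0, amax / 448.0, torch.ones_like(amax))  # 448 = e4m3 max
    q = (t / scale).to(torch.float8_e4m3fn)
    return q, scale  # the scaling factor is stored alongside the FP8 data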

3. More Accurate Scaling Factors

Is there a way to maintain more accurate amax by estimating:

  • For e.g. naive SGD case:
    • scaling_factor_weights_t = amax_weights_t-1 + amax_grad_t - this is an accurate upper bound (no a priori knowledge needed; see the sketch after this list)
    • amax_weights_t = max(abs(weights_t)) - this is only used for the next iteration
  • For Adam optimizer:
    • Utilizing e5m2 might be able to help with dynamic range for v (same dynamic range as FP16).
    • Storing sqrt_v rather than v may help the precision. Update rule: see appendix
      • Intuition: sqrt will reduce the dynamic range of bits by half (2^16 -> 2^8, 2^-16 -> 2^-8). Hence we perform sqrt in fp32/fp16 and quantize that as fp8, thus preserving the dynamic range
    • A more rigorous analysis is needed here.
  • If it is possible to better estimate scaling_factor_weights_t then it may be possible to use more of the dynamic range. Hence, storing the weights as FP8 (rather than FP16 as in the MS-AMP repo) might be possible.
    • Since the Adam optimizer is momentum-based, the effect of per-batch deviations in amax is more bounded.
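A minimal sketch of the upper-bound idea from the naive SGD bullet above (names such as sgd_step_fp8 are illustrative; the bound below folds in the learning rate, which makes it slightly tighter than the one stated in the bullet):

import torch

def sgd_step_fp8(w_fp8, scale_w, amax_w_prev, grad, lr):
    # Upper bound on amax(w_t), known *before* the update:
    # |w_{t-1} - lr * g_t| <= amax(w_{t-1}) + lr * amax(g_t).
    amax_bound = amax_w_prev + lr * grad.abs().max()
    new_scale = amax_bound / 448.0  # divisor convention, as in the appendix

    # Dequantize, apply the SGD update, requantize with the bound-derived scaling factor.
    w = w_fp8.to(torch.float32) * scale_w - lr * grad
    w_fp8_new = (w / new_scale).to(torch.float8_e4m3fn)

    # The exact amax is only needed for the next iteration's bound.
    amax_w_new = w.abs().max()
    return w_fp8_new, new_scale, amax_w_new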

4. Adaptive Precision

Has using lower precision (FP8) at high learning rates (earlier epochs) and higher precision (e.g. FP32, FP16) at lower learning rates (later epochs) been explored?

No. This approach requires reserving enough memory in earlier epochs to store the high-precision weights needed in later stages, which may not be as efficient as using high-precision weights with low-bit computation.