Closed adrienchaton closed 2 years ago
Hi @adrienchaton these are valid concerns for sure. The models were trained with AMP, so you can safely run inference in fp32 or AMP modes. As for fp16, one of the numerically unstable parts is the rotary embeddings; you can probably force them into fp32 and compare the outputs if you want to keep precision while still benefiting from the speed-up.
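The suggestion above can be sketched roughly as follows. This is a simplified, hypothetical rotary implementation (ESM-2's actual rotary code differs in detail): the rotary math is forced into fp32 and the result is cast back to the model's fp16 dtype.

```python
import torch

def rotate_half(x):
    # split the last dim in two and swap the halves with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_fp32(q, cos, sin):
    # do the rotary arithmetic in fp32, then cast back to the input dtype
    out = q.float() * cos.float() + rotate_half(q.float()) * sin.float()
    return out.to(q.dtype)

# fp16 query tensor; only the rotary math itself runs in fp32
q = torch.randn(1, 8, 16).half()
pos = torch.arange(8, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, dtype=torch.float32) / 8))
freqs = torch.outer(pos, inv_freq)        # (seq_len, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim)
q_rot = apply_rotary_fp32(q, emb.cos(), emb.sin())
```

The output stays fp16, so the rest of the forward pass keeps its half-precision speed and memory benefits.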
I see, thank you for the details @nikitos9000. For now I am sticking to fp32 inference, but I will report back if I do more experiments with fp16 inference.
Small correction - we train with --fp16, which implements the "Mixed Precision" scheme described by Nvidia in 2017. In short, Automatic Mixed Precision (AMP) is different in that the numerically unstable parts of forward() are done in fp32.
So given that info, fp16 should be fine, since that's how the LMs were trained. However, there are some reports here in the issues of larger-than-expected numerical instability in fp16. But

> residue embeddings are not allclose, i.e. there is an average absolute difference of around 1e-3 or 1e-4.
That seems like a pretty reasonable numerical error due to fp16. The real litmus test would be whether downstream performance is impacted.
Thanks @tomsercu for the details, I see the difference between training in pure FP16 and mixed precision (where e.g. batch norm is usually kept in FP32).
I agree it is hard to quantify the effect of these numerical deltas, and that if they cannot be brought down to close to 0, the only test is downstream performance.
Hi ESM team
Thanks for open-sourcing the ESM2 models, this is great work!
I have seen the reply to this issue https://github.com/facebookresearch/esm/issues/259#issuecomment-1225456012, yet I would like to discuss further if possible. Half precision is very tempting for inference as the models grow, both to speed up computation and to save memory. I am walking on slippery ground here, so please correct me where I am wrong.
If the models have been trained in mixed or half precision (e.g. with Apex), I would imagine they could be used for inference in both fp32 and fp16. Whereas if they were trained in fp32, I would imagine we may lose accuracy when doing inference in fp16. Of course, these are just suppositions.
Now, by default the checkpoints are loaded in fp32. If I then cast the models with .half() and set the input tokens to torch.int32, I observe that the predicted logits and output residue embeddings are not allclose, i.e. there is an average absolute difference of around 1e-3 or 1e-4.
What should we trust more, FP16 or FP32? For the embedding values, it is quite hard to judge what impact this difference would have on e.g. transfer learning. For the logit values, I guess things like zero-shot scoring would not be affected much by half precision, but I still wonder if this gap could be reduced. I.e. is there a better way to run inference in half precision than the basic casting I have described?
Thanks for your hints!