facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins

inference in half precision #283

Closed adrienchaton closed 2 years ago

adrienchaton commented 2 years ago

Hi ESM team

Thanks for open-sourcing the ESM2 models, this is great work!

I have seen the reply in https://github.com/facebookresearch/esm/issues/259#issuecomment-1225456012, but I would like to discuss this further if possible. Half precision is very tempting for inference as the models grow, to speed up computation and save memory. I am on slippery ground here, so please correct me where I am wrong.

If the models were trained in mixed or half precision (e.g. with Apex), I would imagine they could be used for inference in both fp32 and fp16. Whereas if they were trained in fp32, I would imagine we could lose accuracy when running inference in fp16. Of course, these are just suppositions.

Now, by default the checkpoints are loaded in fp32. If I then cast the model with .half() and set the input tokens to torch.int32, I observe that the predicted logits and output residue embeddings are not allclose between the two precisions, i.e. there is an average absolute difference of around 1e-3 to 1e-4 (roughly along the lines of the sketch below).
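For reference, a minimal sketch of what I am doing, assuming the esm2_t33_650M_UR50D checkpoint and the loading calls from the repo README (adjust to whichever model you actually use):

```python
# Sketch of the fp32 vs fp16 comparison described above.
import torch
import esm

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval().cuda()

data = [("protein1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]
_, _, tokens = batch_converter(data)
tokens = tokens.cuda()  # token ids stay integers; only the weights change dtype

with torch.no_grad():
    out32 = model(tokens, repr_layers=[33])  # fp32 reference
    model.half()                             # cast weights to fp16
    out16 = model(tokens, repr_layers=[33])

for key, a, b in [
    ("logits", out32["logits"], out16["logits"]),
    ("embeddings", out32["representations"][33], out16["representations"][33]),
]:
    diff = (a.float() - b.float()).abs()
    print(f"{key}: mean |diff| = {diff.mean().item():.2e}, max |diff| = {diff.max().item():.2e}")
```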

Which should we trust more, fp16 or fp32? For embedding values, it is quite hard to judge what impact this difference would have on e.g. transfer learning. For logit values, I guess things like zero-shot scoring would not be affected much by half precision, but I still wonder if this gap could be reduced. That is, is there a better way to run inference in half precision than the basic casting I have described?

Thanks for your hints!

nikitos9000 commented 2 years ago

Hi @adrienchaton, these are valid concerns for sure. The models were trained with AMP, so I guess you can safely run inference in fp32 or AMP mode. As for fp16, one of the numerically unstable parts is the rotary embeddings; you could force them to fp32 and compare the outputs if you want to keep precision while still benefiting from the speed-up (see the sketch below).
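A sketch of that idea: cast the model to fp16 but run the rotary embedding computation in fp32, casting its outputs back to fp16. The RotaryEmbedding class and its import path are assumptions based on the esm source tree (esm/rotary_embedding.py) and may differ between versions.

```python
import torch
import esm
from esm.rotary_embedding import RotaryEmbedding  # assumed import path

class FP32Rotary(torch.nn.Module):
    """Run a wrapped RotaryEmbedding in fp32, then cast back to the input dtype."""
    def __init__(self, rotary):
        super().__init__()
        self.rotary = rotary.float()  # restore fp32 buffers after model.half()

    def forward(self, q, k):
        q32, k32 = self.rotary(q.float(), k.float())
        return q32.to(q.dtype), k32.to(k.dtype)

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval().cuda().half()

# Collect first, then replace, so we do not mutate modules while iterating.
targets = [
    (parent, name, child)
    for parent in model.modules()
    for name, child in parent.named_children()
    if isinstance(child, RotaryEmbedding)
]
for parent, name, child in targets:
    setattr(parent, name, FP32Rotary(child))
```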

adrienchaton commented 2 years ago

I see, thank you for the details @nikitos9000. For now I am sticking to fp32 inference, but I will report back if I run more experiments on fp16 inference.

tomsercu commented 2 years ago

Small correction: we train with --fp16, which implements the "Mixed Precision" recipe described by Nvidia in 2017. In short, the forward and backward passes run in fp16, with fp32 master weights and loss scaling.

Automatic Mixed Precision (AMP) is different in that the numerically unstable parts of forward() are kept in fp32.
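For comparison, an AMP-style inference sketch would keep the fp32 weights and let PyTorch autocast pick per-op precision during the forward pass (checkpoint and converter names are the usual esm.pretrained ones and may need adjusting):

```python
import torch
import esm

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval().cuda()  # no .half(): weights remain fp32

data = [("protein1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]
_, _, tokens = batch_converter(data)

with torch.no_grad(), torch.cuda.amp.autocast():
    out = model(tokens.cuda(), repr_layers=[33])
```

Note that this mostly buys compute speed: weight memory stays at fp32, unlike a full .half() cast.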

So given that, fp16 should be fine, since that is how the LMs were trained. However, there are some reports here in the issues of larger-than-expected numerical instability in fp16. But as for

> residue embeddings are not allclose, i.e. there is an average absolute difference of around 1e-3 or 1e-4.

that seems like a pretty reasonable numerical error due to fp16. The real litmus test would be whether downstream performance is impacted.
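As a quick first check before a full downstream run, you could look at how much the fp16 logits actually move the predictions. The helper below is purely illustrative (fp16_sanity_check is not part of the repo); it takes the logits from an fp32 and an fp16 run, e.g. from the comparison sketch above.

```python
import torch

def fp16_sanity_check(logits_fp32: torch.Tensor, logits_fp16: torch.Tensor) -> None:
    """Both tensors shaped (batch, seq_len, vocab); prints simple agreement stats."""
    lp32 = torch.log_softmax(logits_fp32.float(), dim=-1)
    lp16 = torch.log_softmax(logits_fp16.float(), dim=-1)
    argmax_agreement = (lp32.argmax(-1) == lp16.argmax(-1)).float().mean().item()
    max_logprob_shift = (lp32 - lp16).abs().max().item()
    print(f"per-position argmax agreement: {argmax_agreement:.4f}")
    print(f"max |log-prob shift|         : {max_logprob_shift:.2e}")
```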

adrienchaton commented 2 years ago

Thanks @tomsercu for the details. I see the difference between training in pure fp16 and in mixed precision (where e.g. batch norm is usually kept in fp32).

I agree that it is hard to quantify the effect of these numerical deltas, and that if they cannot be brought down close to 0, the only real test is downstream performance.