meta-llama / llama

Inference code for Llama models

Why does RMSNorm have to be performed in fp32 precision instead of fp16? #1048

Open xiabingquan opened 9 months ago

xiabingquan commented 9 months ago

Describe the bug

When running inference with LLaMA-2-7B, I found that RMSNorm has to be performed in fp32 precision: when it is performed in fp16 instead, the generation results are much worse than with fp32.

I didn't test larger models such as LLaMA-2-13B or LLaMA-2-70B.

There are many other places where operations are performed in fp32, for example:

https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L102
https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L157
https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L301
https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L494

However, when I replaced these with fp16 one by one, none of them degraded the output the way RMSNorm did.

Minimal reproducible example

In RMSNorm, replace the following line https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L76 with output = self._norm(x)
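For context, the forward pass of RMSNorm in model.py looks roughly like the sketch below (see the linked file for the exact code); the repro only drops the fp32 upcast on the marked line:

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # normalize by the root mean square over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # original: upcast to fp32 for the normalization, then cast back to the input dtype
        output = self._norm(x.float()).type_as(x)
        # fp16 variant used in this repro: output = self._norm(x)
        return output * self.weight
```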

Output

I tested two prompts:

When RMSNorm is performed in fp32, the generation results look normal, even though there are some repetitions: [screenshot]

When RMSNorm is performed in fp16, the generation completely falls apart: [screenshot]
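My guess (I haven't verified this) is that the fp16 failure comes from overflow inside _norm: squaring even moderately large activations already exceeds the fp16 maximum of ~65504, so x.pow(2).mean(-1) can become inf and the normalization collapses. A tiny standalone illustration of the overflow (the value 300 is just a made-up magnitude):

```python
import torch

x = torch.tensor([300.0] * 8, dtype=torch.float16)  # hypothetical activations of magnitude ~300

sq16 = x * x          # squared in fp16
sq32 = x.float() * x  # same computation with an fp32 upcast

print(sq16[0])  # inf  -> 300**2 = 90000 exceeds the fp16 max of ~65504
print(sq32[0])  # tensor(90000.)
```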

Runtime Environment

xiabingquan commented 9 months ago

FYI, I noticed that CLIP also performs LayerNorm in fp32 precision: https://github.com/openai/CLIP/blob/a1d071733d7111c9c014f024669f959182114e33/clip/model.py#L157-L163
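The pattern there, roughly (sketched from memory, so treat it as an approximation rather than the exact source), is a thin nn.LayerNorm subclass that upcasts the input to fp32, runs the normalization, and casts the result back to the original dtype:

```python
import torch
from torch import nn

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to run the computation in fp32."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))  # normalize in fp32
        return ret.type(orig_type)                    # cast back to the input dtype (e.g. fp16)
```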