Describe the bug
When running inference with LLaMA-2-7B, I found that RMSNorm has to be performed in fp32 precision. When RMSNorm is instead performed in fp16, the generation results are much worse than with fp32. (I did not test larger models such as LLaMA-2-13B or LLaMA-2-70B.)
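For reference, RMSNorm in llama/model.py upcasts to fp32 inside forward and casts back afterwards; roughly like the sketch below (paraphrasing the linked file, so check it for the exact code):

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x * 1/sqrt(mean(x^2) + eps), computed in whatever dtype x has
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the fp32 upcast happens here (model.py#L76); dropping .float() keeps
        # the whole normalization in fp16 when the activations are fp16
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
```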
There are several other places where operations are performed in fp32, for example:
https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L102
https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L157
https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L301
https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L494
However, when I replaced these with fp16 one by one, I did not observe the same degradation that RMSNorm shows.
Minimal reproducible example
In RMSNorm.forward, replace the following line https://github.com/facebookresearch/llama/blob/6796a91789335a31c8309003339fe44e2fd345c2/llama/model.py#L76 with:
output = self._norm(x)
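To compare the two paths outside the model, a standalone sketch like the one below can be used (the tensor shape, scale, and seed are my own assumptions for illustration, not values from this issue):

```python
import torch

torch.manual_seed(0)
# Synthetic fp16 activations; move tensors to CUDA if your PyTorch build
# lacks some fp16 ops on CPU.
x = (torch.randn(1, 4096) * 30).half()

def rms_norm(t: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # t * rsqrt(mean(t^2) + eps), computed entirely in t's own dtype
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)

out_fp32_path = rms_norm(x.float()).type_as(x)  # upcast to fp32, cast back (model.py#L76)
out_fp16_path = rms_norm(x)                     # pure fp16, as in the modified line above

print("max abs diff:", (out_fp32_path - out_fp16_path).abs().max().item())
```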
Output
I tested two prompts:
When RMSNorm is performed in fp32, the generation results look normal, apart from some repetition:
When RMSNorm is performed in fp16, the generated text completely breaks down:
Runtime Environment