## Runtime Environment
- Model: `llama-2-7b`
- Using via huggingface?: no
- OS: Ubuntu
- GPU VRAM: 64 GB
- Number of GPUs: 8
- GPU Make: AMD MI250
**Additional context**
CUDA supports float16, which is more efficient. See L118 of `generation.py`, where this is set as the default dtype. You can comment that line out to load the model as bf16 if you'd like.
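For reference, here is a minimal sketch of the pattern being discussed, assuming the default tensor type is set with `torch.set_default_tensor_type` around generation.py#L118; the bf16 alternatives shown are assumptions about how one might keep the checkpoint dtype, not code from the repo:

```python
import torch

# Sketch of what generation.py effectively does around L118: every tensor
# created afterwards (including model parameters) defaults to fp16 on CUDA.
torch.set_default_tensor_type(torch.cuda.HalfTensor)

# Hedged alternatives if you want to keep the checkpoint's bfloat16 instead
# (comment out the line above; availability depends on your PyTorch version):
# torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
# torch.set_default_dtype(torch.bfloat16)  # sets dtype only; device placement still needs handling
```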
When I run inference as the README shows and then inspect the model weight dtypes, I get the following output (a reproduction sketch is given after the output):
Checkpoint loaded at https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/generation.py#L106:
`checkpoint['layers.31.ffn_norm.weight'].dtype` -> `torch.bfloat16`

Model parameters after https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/generation.py#L120:
`model.layers.31.ffn_norm.weight.dtype` -> `torch.float16`
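For completeness, a hedged sketch of the check above; it assumes the repo's `Llama.build` API, uses placeholder checkpoint/tokenizer paths, and needs the same torchrun/distributed setup as the example scripts:

```python
import torch
from llama import Llama  # from the facebookresearch/llama repo

# Inspect the dtype stored in the checkpoint on disk (path is a placeholder).
ckpt = torch.load("llama-2-7b/consolidated.00.pth", map_location="cpu")
print(ckpt["layers.31.ffn_norm.weight"].dtype)  # torch.bfloat16

# Build the model the same way the example scripts do, then inspect the
# same parameter after loading.
generator = Llama.build(
    ckpt_dir="llama-2-7b",
    tokenizer_path="tokenizer.model",
    max_seq_len=128,
    max_batch_size=1,
)
param = dict(generator.model.named_parameters())["layers.31.ffn_norm.weight"]
print(param.dtype)  # torch.float16 after the default-dtype cast
```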
As https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/generation.py#L118 shows, the dtype is changed there, but why is this dtype change made?
As far as I know, torch.float16 = 1 sign bit + 5 exponent bits + 10 mantissa bits, while torch.bfloat16 = 1 sign bit + 8 exponent bits + 7 mantissa bits. Therefore, after converting bfloat16 -> float16, an extreme value (e.g., one whose exponent falls outside float16's 5-bit range) will overflow or lose accuracy.
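A minimal sketch of that range issue (the tensor values are made-up examples, not weights from the model):

```python
import torch

# bfloat16 covers roughly the same exponent range as float32 (up to ~3.4e38),
# while float16 tops out around 65504 and loses subnormals below ~6e-8.
x_bf16 = torch.tensor([3.0e5, 1.0e-10, 1.2345], dtype=torch.bfloat16)

x_fp16 = x_bf16.to(torch.float16)

print(x_bf16)  # all three values are representable in bfloat16
print(x_fp16)  # 3.0e5 overflows to inf, 1.0e-10 underflows to 0
```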