meta-llama / llama

Inference code for Llama models

model weights dtype change in Llama.build #1032

Open chunniunai220ml opened 7 months ago

chunniunai220ml commented 7 months ago

When I run inference as the README shows:

CUDA_VISIBLE_DEVICES=5,6 \
  torchrun --nproc_per_node 1 example_text_completion.py \
    --ckpt_dir llama-2-7b/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 128 --max_batch_size 4

and then inspect the model weight dtypes with:

# print the dtype of every parameter after Llama.build
for name, param in model.named_parameters():
    print(name, param.dtype)

the output shows:

https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/generation.py#L106

checkpoint['layers.31.ffn_norm.weight'].dtype

-> torch.bfloat16

https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/generation.py#L120

model.layers.31.ffn_norm.weight

-> torch.float16

As https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/generation.py#L118 shows, the dtype is changed during loading. Why is this dtype change made?
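For context, the change seems to happen when the bf16 checkpoint tensors are copied into parameters that were already created as float16 (because line 118 sets float16 as the default dtype before the model is built). A minimal standalone sketch of that casting behavior (not the llama code itself):

import torch
import torch.nn as nn

# load_state_dict copies each checkpoint tensor into the module's existing
# parameter, casting it to the parameter's dtype in the process.
layer = nn.Linear(4, 4).half()  # parameters are float16
state = {
    "weight": torch.randn(4, 4, dtype=torch.bfloat16),
    "bias": torch.randn(4, dtype=torch.bfloat16),
}
layer.load_state_dict(state)
print(layer.weight.dtype)  # torch.float16, even though the checkpoint tensors were bfloat16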

As far as I know, torch.float16 is 1 sign bit + 5 exponent bits + 10 mantissa bits, while torch.bfloat16 is 1 sign bit + 8 exponent bits + 7 mantissa bits. Therefore, after converting bfloat16 -> float16, a value with an extreme exponent (e.g. exp: 0011111) will lose accuracy or overflow, since float16's largest finite value is only about 65504 while bfloat16 can represent magnitudes up to roughly 3.4e38.
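A quick illustration of that concern (a minimal sketch; the specific values are only examples):

import torch

# float16's largest finite value is 65504, but bfloat16 shares float32's
# 8-bit exponent, so large bf16 values overflow to inf when cast to fp16.
big = torch.tensor(1e6, dtype=torch.bfloat16)
print(big.to(torch.float16))   # inf

# Very small bf16 values can underflow to zero in fp16
# (the smallest fp16 subnormal is about 6e-8).
tiny = torch.tensor(1e-8, dtype=torch.bfloat16)
print(tiny.to(torch.float16))  # 0.0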

## Runtime Environment

- Model: `llama-2-7b`
- Using via huggingface?: no
- OS: Ubuntu
- GPU VRAM: 64G
- Number of GPUs: 8
- GPU Make: AMD mi250
subramen commented 7 months ago

CUDA supports float16, which is more efficient. See line 118, where float16 is set as the default dtype. You can comment that out to load the model as bf16 if you'd like.
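A sketch of what that change could look like, assuming line 118 is the `torch.set_default_tensor_type(torch.cuda.HalfTensor)` call in `Llama.build` (check the exact line in your checkout):

# llama/generation.py, inside Llama.build

# Default behavior: parameters are created as float16 on CUDA, so the
# bf16 checkpoint is cast to fp16 when loaded.
# torch.set_default_tensor_type(torch.cuda.HalfTensor)

# To keep the weights in bfloat16 instead (if the GPU supports bf16):
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)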