huggingface / transformers-bloom-inference

Fast Inference Solutions for BLOOM

Should I use bf16 or fp16? #69

Closed richarddwang closed 1 year ago

richarddwang commented 1 year ago

Since bf16 and fp16 are different floating-point formats, which should I use for bigscience/bloomz and bigscience/bloom? Or does loading in bf16 or fp16 produce the same results?

mayank31398 commented 1 year ago

Great question. The model was trained in bf16, so ideally I would recommend that. However, from some testing I found that fp16 works better for inference with BLOOM, but that was tested on only a single dataset, so even I am not sure.
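
For reference, a minimal sketch of loading the model in either precision with transformers (assuming the standard AutoModelForCausalLM API; the torch_dtype argument controls which precision the checkpoint is loaded in):

import torch
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloomz"  # or "bigscience/bloom"

# load in bf16, the dtype the checkpoint was trained in
model_bf16 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# load in fp16; the bf16 weights are cast at load time, which perturbs them slightly
model_fp16 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)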

richarddwang commented 1 year ago

It surprises me that a bf16-trained model can be loaded in fp16 at all, since I thought it would produce total nonsense. Anyway, many thanks for your fast reply. Now I know I should probably stick with bf16.
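
For intuition, a small sketch of why the cast only perturbs values slightly instead of producing nonsense: both formats are 16 bits, but bf16 uses an fp32-sized exponent with 8 significant bits of mantissa, while fp16 has a smaller exponent range with 11 significant bits, so the same number simply rounds a little differently in each:

import torch

x = torch.tensor(0.1, dtype=torch.bfloat16)
print(x.item())  # ~0.10009765625 (bf16: coarser mantissa, wide range)

y = torch.tensor(0.1, dtype=torch.float16)
print(y.item())  # ~0.0999755859375 (fp16: finer mantissa, narrow range)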

richarddwang commented 1 year ago

I have conducted a small experiment, which again suggests that we should load checkpoints in their original data types.

import torch
from torch import nn

# create a Linear layer whose parameters default to fp16 and save its state dict
torch.set_default_dtype(torch.half)
a = nn.Linear(100,100)
torch.save(a.state_dict(), "/tmp/test.bin")

# If we load tensors in a dtype different from the one they were saved in, will the values change? Yes, see below

torch.set_default_dtype(torch.bfloat16)
b = nn.Linear(100,100)
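# load_state_dict copies the saved fp16 tensors into b's bf16 parameters,
# casting each value to the parameter's dtype on the fly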
b.load_state_dict(torch.load("/tmp/test.bin"))
print(b.weight.dtype) # torch.bfloat16
print(torch.allclose(a.weight, b.weight.half())) # False
print(torch.max(a.weight-b.weight.half())) # 0.0002

# Is this just a rare corner case? No, see below

mask = (a.weight-b.weight.half()) == torch.max(a.weight-b.weight.half())  # positions where the difference is at its maximum
print(mask.sum()) # 268 elements hit the maximum difference

# Does the difference come only from converting bf16 back to fp16 for the comparison? No, the stored values already differ, see below

print(a.weight[mask][0], a.weight.dtype, b.weight[mask][0], b.weight.dtype) # -0.0906, torch.float16, -0.0908, torch.bfloat16

# Is loading in the original dtype and then converting to the target dtype any safer? No, see below

torch.set_default_dtype(torch.half)
c = nn.Linear(100, 100)
c.load_state_dict(torch.load("/tmp/test.bin"))

assert torch.equal(c.weight, a.weight)
print(torch.max(c.weight.bfloat16().half() - a.weight)) # 0.0002
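
A rough sanity check on the size of the difference: bf16 keeps only 8 significant bits, so around 0.09 neighbouring bf16 values are 2**-11 ≈ 0.00049 apart, and round-to-nearest can move a weight by up to half of that, which matches the 0.0002 printed above. A quick check using the -0.0906 weight from the mask as an example:

import torch

w = torch.tensor(-0.0906, dtype=torch.float16)
print(w.bfloat16().item())               # ~-0.0908: nearest representable bf16 value
print((w - w.bfloat16().half()).item())  # ~0.000244: half a bf16 step at this magnitude
print(torch.finfo(torch.bfloat16).eps)   # 0.0078125 = 2**-7, the relative spacing of bf16 values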