EleutherAI / elk

Keeping language models honest by directly eliciting knowledge encoded in their activations.

Load fp32 models in bfloat16 when possible #231

Closed · norabelrose closed this 1 year ago

norabelrose commented 1 year ago

Several models that we'd like to evaluate on, like bigscience/mt0-xxl and allenai/unifiedqa-t5-11b, have float32 checkpoints but were actually trained in bfloat16 on TPUs. Because they're float32, we get out-of-memory errors when trying to run inference on them. This PR automatically detects whether a checkpoint is (likely) float32 before downloading it, and sets torch_dtype=torch.bfloat16 iff torch.cuda.is_bf16_supported() returns True.
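A minimal sketch of that dtype-selection logic, assuming the checkpoint's declared dtype can be read from its Hugging Face config; the helper name `select_torch_dtype` and its exact heuristics are illustrative, not elk's actual implementation:

```python
import torch
from transformers import AutoConfig


def select_torch_dtype(model_str: str) -> torch.dtype:
    """Pick bfloat16 for (likely) float32 checkpoints when the GPU supports it."""
    config = AutoConfig.from_pretrained(model_str)

    # Some configs store torch_dtype as a string like "float32"; normalize it.
    checkpoint_dtype = getattr(config, "torch_dtype", None)
    if isinstance(checkpoint_dtype, str):
        checkpoint_dtype = getattr(torch, checkpoint_dtype, None)

    # A missing dtype is treated as fp32, since that is the transformers default.
    is_fp32 = checkpoint_dtype in (None, torch.float32)

    if is_fp32 and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return checkpoint_dtype or torch.float32
```

The returned dtype would then be passed as `torch_dtype` to `from_pretrained`, e.g. `AutoModelForSeq2SeqLM.from_pretrained("allenai/unifiedqa-t5-11b", torch_dtype=select_torch_dtype("allenai/unifiedqa-t5-11b"))`.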

Some older models, like gpt2, have fp32 checkpoints and really were trained in full precision. But it's nearly impossible for an overflow to occur when running these models in bfloat16, since bf16 has a dynamic range almost equal to that of fp32. There is some precision loss, but empirically neural nets are highly robust to it as long as no overflows occur, so this should be fine. We also print a warning whenever the downcast happens. Maybe we should add a flag to turn off this automatic downcasting, but I haven't included one in this PR for simplicity.
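For illustration only, the downcast warning plus the opt-out flag discussed above could be layered on the sketch from the previous comment; the `allow_bf16_downcast` parameter is hypothetical and is not part of this PR.

```python
import warnings
from typing import Optional

import torch


def resolve_dtype(
    checkpoint_dtype: Optional[torch.dtype],
    model_str: str,
    allow_bf16_downcast: bool = True,  # hypothetical opt-out flag, not in this PR
) -> torch.dtype:
    """Downcast fp32 checkpoints to bf16, warning the user when that happens."""
    is_fp32 = checkpoint_dtype in (None, torch.float32)
    if is_fp32 and allow_bf16_downcast and torch.cuda.is_bf16_supported():
        warnings.warn(
            f"{model_str} looks like a float32 checkpoint; loading it in bfloat16 "
            "to save memory."
        )
        return torch.bfloat16
    return checkpoint_dtype or torch.float32
```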