Several models that we'd like to evaluate on, like `bigscience/mt0-xxl` and `allenai/unifiedqa-t5-11b`, have float32 checkpoints but were actually trained in bfloat16 on TPUs. Because they're float32, we get out-of-memory errors when trying to run inference on them. This PR automatically detects whether a checkpoint is (likely) float32 before downloading it, and sets `torch_dtype=torch.bfloat16` iff `torch.cuda.is_bf16_supported()` returns `True`.
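A minimal sketch of the selection logic (the function name `choose_torch_dtype` and the string-based dtype handling are illustrative, not the PR's actual code):

```python
def choose_torch_dtype(checkpoint_dtype, bf16_supported):
    """Pick the dtype to pass as `torch_dtype` when loading a checkpoint.

    `checkpoint_dtype` is the `torch_dtype` string from the model's
    config.json; it is often "float32", or missing for older checkpoints
    (which were almost always saved in full precision).
    """
    if checkpoint_dtype in (None, "float32") and bf16_supported:
        # Downcast likely-fp32 checkpoints to halve memory use, and warn.
        print("Warning: downcasting a float32 checkpoint to bfloat16.")
        return "bfloat16"
    # Otherwise keep the checkpoint's own dtype (e.g. "float16"),
    # falling back to float32 when the config doesn't say.
    return checkpoint_dtype or "float32"
```

In practice the result would be passed to `from_pretrained(..., torch_dtype=...)`, with `bf16_supported` coming from `torch.cuda.is_bf16_supported()`.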
Some older models, like `gpt2`, have fp32 checkpoints and were simply trained in full precision. But it's nearly impossible for an overflow to occur when running these models in bfloat16, since bf16 has a dynamic range almost equal to that of fp32. There is some precision loss, but empirically neural nets are highly robust to this, as long as there are no overflows. So this should be fine. We also print a warning whenever the downcasting occurs. We could add a flag to turn off the automatic downcasting, but I haven't included one in this PR for simplicity.
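The dynamic-range point can be checked numerically: bf16 is just fp32 with the low 16 bits of the bit pattern dropped (same sign and exponent fields), so truncating an fp32 value's bits simulates the cast. A pure-Python illustration, not how PyTorch actually casts:

```python
import struct

def to_bf16(x):
    """Simulate an fp32 -> bf16 cast by zeroing the low 16 bits of the
    fp32 bit pattern; bf16 keeps fp32's sign bit and 8-bit exponent."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

# Huge values survive: 3e38 would overflow fp16 (max ~65504) but not bf16,
# because bf16 shares fp32's exponent range (max ~3.4e38).
print(to_bf16(3e38))        # still on the order of 3e38, finite
# Precision is coarse: bf16 keeps only 8 bits of mantissa,
# so 1 + 2**-8 truncates down to exactly 1.0.
print(to_bf16(1.00390625))
```

This is why overflow, not rounding, is the failure mode to worry about when downcasting.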