Load fp32 models in bfloat16 when possible #231
Merged
Several models that we'd like to evaluate on, like `bigscience/mt0-xxl` and `allenai/unifiedqa-t5-11b`, have float32 checkpoints but were actually trained in bfloat16 on TPUs. Because they're float32, we get out-of-memory errors when trying to run inference on them. This PR automatically detects whether a checkpoint is (likely) float32 before downloading it, and sets `torch_dtype=torch.bfloat16` if `torch.cuda.is_bf16_supported()` is True.
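
Roughly, the loading logic looks like the sketch below. This is illustrative, not the exact code in this PR: the `load_maybe_bf16` helper name, the use of `AutoModel`, and the `config.torch_dtype`-based dtype check are assumptions made for the example.

```python
import torch
from transformers import AutoConfig, AutoModel


def load_maybe_bf16(model_name: str):
    """Load a model, downcasting an fp32 checkpoint to bfloat16 when the GPU supports it.

    Sketch only; the actual detection in this PR may differ.
    """
    config = AutoConfig.from_pretrained(model_name)
    # config.torch_dtype records the dtype the weights were saved in, when the
    # config provides it; it may be None (older checkpoints), a torch.dtype, or a string.
    saved_dtype = getattr(config, "torch_dtype", None)
    likely_fp32 = saved_dtype in (None, torch.float32, "float32")

    dtype = None  # None means "use the default loading behavior"
    if likely_fp32 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        print(f"Warning: loading {model_name} in bfloat16 rather than float32 to save memory.")
        dtype = torch.bfloat16

    return AutoModel.from_pretrained(model_name, torch_dtype=dtype)


# e.g. model = load_maybe_bf16("bigscience/mt0-xxl")
```
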
Some older models, like `gpt2`, have fp32 checkpoints and were just trained in full precision. But it's nearly impossible for an overflow to occur when running these models in bfloat16, since bf16 has a dynamic range almost equal to that of fp32. There is a bit of precision loss, but empirically neural nets are highly robust to this, as long as there aren't any overflows. So this should be fine. We also print a warning when the downcasting does occur. Maybe we should add a flag to turn off this automatic downcasting, but I haven't included it in this PR for simplicity.