Load fp32 models in bfloat16 when possible (#231)
* Automatically use bfloat16 in some cases

* Use bfloat16 in more cases; sanity check for int8
norabelrose committed May 3, 2023
1 parent: b889473 · commit: 2d88580
Showing 4 changed files with 44 additions and 18 deletions.
elk/extraction/extraction.py (3 additions, 13 deletions)
@@ -31,7 +31,7 @@
     Color,
     assert_type,
     colorize,
-    float32_to_int16,
+    float_to_int16,
     infer_label_column,
     infer_num_classes,
     instantiate_model,
@@ -165,20 +165,10 @@ def extract_hiddens(
     ds_names = cfg.datasets
     assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."
 
-    if cfg.int8:
-        # Required by `bitsandbytes`
-        dtype = torch.float16
-    elif device == "cpu":
-        dtype = torch.float32
-    else:
-        dtype = "auto"
-
     # We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its
     # welcome message on every rank
     with redirect_stdout(None) if rank != 0 else nullcontext():
-        model = instantiate_model(
-            cfg.model, device_map={"": device}, load_in_8bit=cfg.int8, torch_dtype=dtype
-        )
+        model = instantiate_model(cfg.model, device=device, load_in_8bit=cfg.int8)
         tokenizer = instantiate_tokenizer(
             cfg.model, truncation_side="left", verbose=rank == 0
         )
@@ -313,7 +303,7 @@ def extract_hiddens(
                 raise ValueError(f"Invalid token_loc: {cfg.token_loc}")
 
             for layer_idx, hidden in zip(layer_indices, hiddens):
-                hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)
+                hidden_dict[f"hidden_{layer_idx}"][i, j] = float_to_int16(hidden)
 
         text_questions.append(variant_questions)

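With this change, extract_hiddens no longer picks a dtype itself; it forwards the target device and the int8 flag, and instantiate_model chooses between fp16, fp32, bf16, and "auto". A minimal usage sketch of the new call signature follows; the model name and device are placeholder values, not part of this diff:

from elk.utils.hf_utils import instantiate_model

# Placeholder checkpoint and device; any HF model string would do.
model = instantiate_model("gpt2", device="cuda:0", load_in_8bit=False)

# For a checkpoint whose config lacks torch_dtype (treated as fp32), this
# should report bfloat16 on a bf16-capable GPU, per the logic added below.
print(model.dtype)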
elk/utils/__init__.py (2 additions, 2 deletions)
@@ -13,15 +13,15 @@
 from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained
 from .pretty import Color, colorize
 from .tree_utils import pytree_map
-from .typing import assert_type, float32_to_int16, int16_to_float32
+from .typing import assert_type, float_to_int16, int16_to_float32
 
 __all__ = [
     "assert_type",
     "batch_cov",
     "Color",
     "colorize",
     "cov_mean_fused",
-    "float32_to_int16",
+    "float_to_int16",
     "get_columns_all_equal",
     "get_layer_indices",
     "has_multiple_configs",
elk/utils/hf_utils.py (37 additions, 1 deletion)
@@ -1,3 +1,4 @@
+import torch
 import transformers
 from transformers import (
     AutoConfig,
@@ -19,10 +20,45 @@
 _AUTOREGRESSIVE_SUFFIXES = ["ConditionalGeneration"] + _DECODER_ONLY_SUFFIXES
 
 
-def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel:
+def instantiate_model(
+    model_str: str,
+    device: str | torch.device = "cpu",
+    **kwargs,
+) -> PreTrainedModel:
     """Instantiate a model string with the appropriate `Auto` class."""
+    device = torch.device(device)
+    kwargs["device_map"] = {"": device}
+
     with prevent_name_conflicts():
         model_cfg = AutoConfig.from_pretrained(model_str)
 
+        # When the torch_dtype is None, this generally means the model is fp32, because
+        # the config was probably created before the `torch_dtype` field was added.
+        fp32_weights = model_cfg.torch_dtype in (None, torch.float32)
+
+        # Required by `bitsandbytes` to load in 8-bit.
+        if kwargs.get("load_in_8bit"):
+            # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint
+            # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and
+            # we can't guarantee that there won't be overflow if we downcast to fp16.
+            if fp32_weights:
+                raise ValueError("Cannot load in 8-bit if weights are fp32")
+
+            kwargs["torch_dtype"] = torch.float16
+
+        # CPUs generally don't support anything other than fp32.
+        elif device.type == "cpu":
+            kwargs["torch_dtype"] = torch.float32
+
+        # If the model is fp32 but bf16 is available, convert to bf16.
+        # Usually models with fp32 weights were actually trained in bf16, and
+        # converting them doesn't hurt performance.
+        elif fp32_weights and torch.cuda.is_bf16_supported():
+            kwargs["torch_dtype"] = torch.bfloat16
+            print("Weights seem to be fp32, but bf16 is available. Loading in bf16.")
+        else:
+            kwargs["torch_dtype"] = "auto"
+
         archs = model_cfg.architectures
         if not isinstance(archs, list):
             return AutoModel.from_pretrained(model_str, **kwargs)
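The heart of the commit is the dtype-selection branch above. The sketch below restates that policy as a standalone function so it can be read or unit-tested without loading a model; pick_torch_dtype is an illustrative name, not something this commit adds:

import torch


def pick_torch_dtype(
    config_dtype: torch.dtype | None,
    device: torch.device,
    load_in_8bit: bool = False,
) -> torch.dtype | str:
    """Mirror of the dtype policy in instantiate_model (illustrative only)."""
    # A missing torch_dtype usually means an old config, i.e. fp32 weights.
    fp32_weights = config_dtype in (None, torch.float32)

    if load_in_8bit:
        # bitsandbytes only does mixed fp16/int8 inference; refuse fp32 checkpoints.
        if fp32_weights:
            raise ValueError("Cannot load in 8-bit if weights are fp32")
        return torch.float16
    if device.type == "cpu":
        # CPUs generally only have fast fp32 kernels.
        return torch.float32
    if fp32_weights and torch.cuda.is_bf16_supported():
        # fp32 checkpoints were usually trained in bf16, so downcasting is safe.
        return torch.bfloat16
    # Otherwise trust whatever dtype the checkpoint config declares.
    return "auto"


# Example: an old fp32-style config on a bf16-capable GPU loads in bfloat16.
# pick_torch_dtype(None, torch.device("cuda"), load_in_8bit=False) -> torch.bfloat16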
elk/utils/typing.py (2 additions, 2 deletions)
@@ -13,8 +13,8 @@ def assert_type(typ: Type[T], obj: Any) -> T:
     return cast(typ, obj)
 
 
-def float32_to_int16(x: torch.Tensor) -> torch.Tensor:
-    """Converts float32 to float16, then reinterprets as int16."""
+def float_to_int16(x: torch.Tensor) -> torch.Tensor:
+    """Converts a floating point tensor to float16, then reinterprets as int16."""
     downcast = x.type(torch.float16)
     if not downcast.isfinite().all():
         raise ValueError("Cannot convert to 16 bit: values are not finite")
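Since float_to_int16 now accepts any floating point input, a quick round-trip check makes a useful sanity test. The sketch below assumes the function returns the raw int16 view of the fp16 bits, as its docstring describes; the tensor values are arbitrary examples:

import torch

from elk.utils import float_to_int16

# bf16 input is fine now, not just fp32.
x = torch.tensor([0.5, -1.25, 3.0], dtype=torch.bfloat16)
packed = float_to_int16(x)                 # fp16 bit pattern stored as int16
restored = packed.view(torch.float16).float()
assert torch.equal(restored, x.float())    # these values are exact in fp16

# Non-finite results are rejected rather than silently kept:
# float_to_int16(torch.tensor([65536.0]))  # overflows fp16 -> ValueError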
