Hardcoded llama-65b multi-GPU with accelerate #228

Closed
wants to merge 16 commits
32 changes: 16 additions & 16 deletions elk/extraction/extraction.py
@@ -34,19 +34,22 @@
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
prevent_name_conflicts,
select_split,
select_train_val_splits,
select_usable_devices,
)
from .dataset_name import (
DatasetDictWithName,
parse_dataset_string,
)
from .generator import _GeneratorBuilder
from .llama.device_configs import (
Llama65bDeviceConfig,
instantiate_model_or_llama,
select_devices_or_llama_65b_configs,
)
from .prompt_loading import load_prompts


@@ -144,11 +147,16 @@ def explode(self) -> list["Extract"]:
def extract_hiddens(
cfg: "Extract",
*,
device: str | torch.device = "cpu",
device_config: str | Llama65bDeviceConfig = "cpu",
split_type: Literal["train", "val"] = "train",
rank: int = 0,
world_size: int = 1,
) -> Iterable[dict]:
"""Run inference on a model with a set of prompts, yielding the hidden states."""
device = (
device_config
if not isinstance(device_config, Llama65bDeviceConfig)
else device_config.first_device
)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -160,20 +168,10 @@
ds_names = cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

if cfg.int8:
# Required by `bitsandbytes`
dtype = torch.float16
elif device == "cpu":
dtype = torch.float32
else:
dtype = "auto"

# We use contextlib.redirect_stdout to prevent `bitsandbytes` from printing its
# welcome message on every rank
with redirect_stdout(None) if rank != 0 else nullcontext():
model = instantiate_model(
cfg.model, device_map={"": device}, load_in_8bit=cfg.int8, torch_dtype=dtype
)
model = instantiate_model_or_llama(cfg=cfg, device_config=device_config)
tokenizer = instantiate_tokenizer(
cfg.model, truncation_side="left", verbose=rank == 0
)
@@ -397,7 +395,9 @@ def extract(
"""Extract hidden states from a model and return a `DatasetDict` containing them."""
info, features = hidden_features(cfg)

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
devices: Sequence[str | Llama65bDeviceConfig] = select_devices_or_llama_65b_configs(
model_name=cfg.model, num_gpus=num_gpus, min_memory=min_gpu_mem
)
limits = cfg.max_examples
splits = assert_type(SplitDict, info.splits)

@@ -433,7 +433,7 @@
),
gen_kwargs=dict(
cfg=[cfg] * len(devices),
device=devices,
device_config=devices,
rank=list(range(len(devices))),
split_type=[ty] * len(devices),
world_size=[len(devices)] * len(devices),
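The list-valued gen_kwargs above are presumably sharded one element per worker process, so each rank receives a single device config. An illustrative sketch, not part of this PR, of what each worker effectively gets (`devices`, `cfg`, and `ty` are the variables from the hunk above):

# Illustrative only: the per-worker view of the list-valued gen_kwargs.
for rank, device_config in enumerate(devices):
    worker_kwargs = dict(
        cfg=cfg,
        device_config=device_config,  # a plain device string or a Llama65bDeviceConfig
        rank=rank,
        split_type=ty,
        world_size=len(devices),
    )
    # Each worker then runs extract_hiddens(**worker_kwargs) in its own process.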
2 changes: 1 addition & 1 deletion elk/extraction/generator.py
@@ -30,7 +30,7 @@ def create_config_id(
config_kwargs["gen_kwargs"] = {
k: v[0]
for k, v in config_kwargs.get("gen_kwargs", {}).items()
if k not in ("device", "rank", "world_size")
if k not in ("device_config", "rank", "world_size")
}
return super().create_config_id(config_kwargs, custom_features)

84 changes: 84 additions & 0 deletions elk/extraction/llama/device_configs.py
@@ -0,0 +1,84 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence

import torch
from transformers import PreTrainedModel

from elk.extraction.llama.device_map import get_llama_65b_8bit_device_map
from elk.utils import instantiate_model, select_usable_devices

if TYPE_CHECKING:
from elk import Extract


@dataclass
class Llama65bDeviceConfig:
first_device: str
second_device: str


def select_devices_or_llama_65b_configs(
model_name: str,
num_gpus: int,
min_memory: int | None = None,
) -> Sequence[str | Llama65bDeviceConfig]:
if "llama-65b" not in model_name:
return select_usable_devices(num_gpus, min_memory=min_memory)
else:
print(
"You've selected a llama-65b model, which requires at least two GPUs."
"Each GPU must have at least 40 GiB of memory."
)
print("Note that we will force the model to use 8-bit")
assert num_gpus >= 2, "llama-65b models require at least two GPUs"
# How many pairs of GPUs were specified?
num_pairs = num_gpus // 2
print(f"Will create {num_pairs} llama workers")
forty_gb = 42_949_672_960  # 40 GiB in bytes
devices = select_usable_devices(num_gpus, min_memory=forty_gb)
# split the devices into pairs
configs = []
while len(configs) < num_pairs:
first_device = devices.pop()
second_device = devices.pop()
configs.append(
Llama65bDeviceConfig(
first_device=first_device, second_device=second_device
)
)
print(f"Created {len(configs)} llama workers")

return configs


def instantiate_model_or_llama(
cfg: "Extract", device_config: str | Llama65bDeviceConfig, **kwargs
) -> PreTrainedModel:
is_llama_65b = isinstance(device_config, Llama65bDeviceConfig)
first_device = device_config.first_device if is_llama_65b else device_config
if cfg.int8 or is_llama_65b:
# Required by `bitsandbytes`
dtype = torch.float16
elif device_config == "cpu":
dtype = torch.float32
else:
dtype = "auto"
if not is_llama_65b:
model = instantiate_model(
cfg.model,
device_map={"": first_device},
load_in_8bit=cfg.int8,
torch_dtype=dtype,
**kwargs,
)
else:
model = instantiate_model(
cfg.model,
device_map=get_llama_65b_8bit_device_map(
first_device=first_device, second_device=device_config.second_device
),
load_in_8bit=True,
torch_dtype=dtype,
**kwargs,
)
return model
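For illustration, a minimal sketch of how the two helpers above might be wired together, assuming four usable GPUs, a hypothetical huggyllama/llama-65b checkpoint, and an already-constructed `Extract` config named `cfg` (none of these values come from this PR):

from elk.extraction.llama.device_configs import (
    instantiate_model_or_llama,
    select_devices_or_llama_65b_configs,
)

# With num_gpus=4 and a llama-65b model name, this returns two
# Llama65bDeviceConfig pairs instead of four plain device strings.
device_configs = select_devices_or_llama_65b_configs(
    model_name="huggyllama/llama-65b", num_gpus=4, min_memory=None
)

# Each worker loads the model sharded across its own pair of GPUs, forced to
# 8-bit as described above. `cfg` is an Extract config whose `model` field
# matches the checkpoint name.
model = instantiate_model_or_llama(cfg=cfg, device_config=device_configs[0])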
117 changes: 117 additions & 0 deletions elk/extraction/llama/device_map.py
@@ -0,0 +1,117 @@
import torch
from accelerate import infer_auto_device_map, init_empty_weights

from elk.utils import instantiate_model


def get_suggested_map(model_str: str, used_dtype: torch.dtype) -> dict[str, int]:
"""Util function to get the suggested map for a given model string and dtype
Usually doesn't work out of the box, you'll need to manually
change the attention module
to the same device as the lm_head due to the residual connection.
"""
with init_empty_weights():
# you need to first instantiate the model to get the suggested map
model = instantiate_model(model_str, torch_dtype=used_dtype)
suggested_map = infer_auto_device_map(model)
return suggested_map


def get_llama_65b_8bit_device_map(
first_device: str | torch.device, second_device: str | torch.device
) -> dict[str, str | torch.device]:
"""
This assumes that you are using 2 GPUs, with at least 40GB of memory each.
and that you are using 8bit
"""
return {
"model.embed_tokens": first_device,
"model.layers.0": first_device,
"model.layers.1": first_device,
"model.layers.2": first_device,
"model.layers.3": first_device,
"model.layers.4": first_device,
"model.layers.5": first_device,
"model.layers.6": first_device,
"model.layers.7": first_device,
"model.layers.8": first_device,
"model.layers.9": first_device,
"model.layers.10": first_device,
"model.layers.11": first_device,
"model.layers.12": first_device,
"model.layers.13": first_device,
"model.layers.14": first_device,
"model.layers.15": first_device,
"model.layers.16": first_device,
"model.layers.17": first_device,
"model.layers.18": first_device,
"model.layers.19": first_device,
"model.layers.20": first_device,
"model.layers.21": first_device,
"model.layers.22": first_device,
"model.layers.23": first_device,
"model.layers.24": first_device,
"model.layers.25": first_device,
"model.layers.26": first_device,
"model.layers.27.self_attn": first_device,
"model.layers.27.mlp.gate_proj": first_device,
"model.layers.27.mlp.down_proj": first_device,
"model.layers.27.mlp.up_proj": first_device,
"model.layers.27.mlp.act_fn": first_device,
"model.layers.27.input_layernorm": first_device,
"model.layers.27.post_attention_layernorm": first_device,
"model.layers.28": first_device,
"model.layers.29": first_device,
"model.layers.30": first_device,
"model.layers.31": first_device,
"model.layers.32": first_device,
"model.layers.33": first_device,
"model.layers.34": second_device,
"model.layers.35": second_device,
"model.layers.36": second_device,
"model.layers.37": second_device,
"model.layers.38": second_device,
"model.layers.39": second_device,
"model.layers.40": second_device,
"model.layers.41": second_device,
"model.layers.42": second_device,
"model.layers.43": second_device,
"model.layers.44": second_device,
"model.layers.45": second_device,
"model.layers.46": second_device,
"model.layers.47": second_device,
"model.layers.48": second_device,
"model.layers.49": second_device,
"model.layers.50": second_device,
"model.layers.51": second_device,
"model.layers.52": second_device,
"model.layers.53": second_device,
"model.layers.54": second_device,
"model.layers.55": second_device,
"model.layers.56": second_device,
"model.layers.57": second_device,
"model.layers.58": second_device,
"model.layers.59": second_device,
"model.layers.60": second_device,
"model.layers.61": second_device,
"model.layers.62": second_device,
"model.layers.63": second_device,
"model.layers.64": second_device,
"model.layers.65": second_device,
"model.layers.66": second_device,
"model.layers.67": second_device,
"model.layers.68": second_device,
"model.layers.69": second_device,
"model.layers.70": second_device,
"model.layers.71": second_device,
"model.layers.72": second_device,
"model.layers.73": second_device,
"model.layers.74": second_device,
"model.layers.75": second_device,
"model.layers.76": second_device,
"model.layers.77": second_device,
"model.layers.78": second_device,
"model.layers.79": second_device,
"model.norm": second_device,
"lm_head": first_device,
}
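The hardcoded map above can be regenerated as a starting point with get_suggested_map and then hand-tuned. A sketch, reusing the hypothetical checkpoint name from earlier:

import torch

# Suggested split for the 65b checkpoint in float16; per the docstring of
# get_suggested_map, the result usually still needs the attention module moved
# onto the same device as lm_head before it matches the hand-tuned map above.
suggested = get_suggested_map("huggyllama/llama-65b", used_dtype=torch.float16)
for module_name, device in suggested.items():
    print(module_name, device)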