WIP: messy draft for llama #220

Closed
wants to merge 77 commits
Changes from 1 commit
Commits
77 commits
0b159f1
stop using the generation builder
thejaminator Apr 26, 2023
f928d60
fix partial issues
thejaminator Apr 26, 2023
203bf5b
test use accelerate
thejaminator Apr 26, 2023
8058a38
test use accelerate
thejaminator Apr 26, 2023
3c0178c
fix devices hopefully
thejaminator Apr 26, 2023
ec5a436
fix partial
thejaminator Apr 26, 2023
4035ac4
print the accelerate device
thejaminator Apr 26, 2023
b22c0dc
give up on accelerate
thejaminator Apr 26, 2023
02c2f11
add fsdp
thejaminator Apr 26, 2023
b384b8a
save changes
thejaminator Apr 26, 2023
bae6e1e
commit wroking
thejaminator Apr 27, 2023
9c0845d
refactor exception
thejaminator Apr 27, 2023
e5c132f
remove unneeded
thejaminator Apr 27, 2023
a8d161f
commit working dumb method
thejaminator Apr 27, 2023
c4a17f6
set format in main process
thejaminator Apr 27, 2023
8b5d0d3
clone tensor
thejaminator Apr 27, 2023
538b119
change fsdp to disable cpu_offload
thejaminator Apr 27, 2023
808b152
set format to torch in test
thejaminator Apr 27, 2023
8ce7326
fix closure bug not sharing memory
thejaminator Apr 27, 2023
15834e8
more logs
thejaminator Apr 27, 2023
99c1cb9
log the output sent back
thejaminator Apr 27, 2023
08f6fbc
more logs
thejaminator Apr 27, 2023
f509688
shift it back to float32
thejaminator Apr 27, 2023
439cf8f
print loaded closure
thejaminator Apr 27, 2023
ae4e052
add logging of sentinel
thejaminator Apr 27, 2023
6ab5e53
fix deadlock maybe?
thejaminator Apr 27, 2023
ae5eb25
add print for breaking
thejaminator Apr 27, 2023
0a0fc40
more prints
thejaminator Apr 27, 2023
6d7fa08
set low min mem for fsdp
thejaminator Apr 27, 2023
820388f
set low min mem for fsdp
thejaminator Apr 27, 2023
bfb8e12
add counter
thejaminator Apr 27, 2023
943ae48
stop destroying the process group
thejaminator Apr 27, 2023
f3aa91c
re log
thejaminator Apr 27, 2023
e813a64
replicate it by 2
thejaminator Apr 27, 2023
6649635
add assertions
thejaminator Apr 27, 2023
61cfce8
add type of exception
thejaminator Apr 27, 2023
44ec152
try increasing timeout
thejaminator Apr 27, 2023
d24ded8
try out not sending the sentinel
thejaminator Apr 27, 2023
7ffcbaf
fix typo
thejaminator Apr 27, 2023
eaf3f42
log more
thejaminator Apr 27, 2023
ea2e2ff
try waiting
thejaminator Apr 27, 2023
6fbacce
add sleep
thejaminator Apr 27, 2023
ea7694e
make it 5
thejaminator Apr 27, 2023
fdd854c
skip destroying group
thejaminator Apr 27, 2023
6b77cbb
try while true
thejaminator Apr 27, 2023
e5960ef
Revert "try while true"
thejaminator Apr 27, 2023
2e823ef
Revert "skip destroying group"
thejaminator Apr 27, 2023
659309b
Revert "make it 5"
thejaminator Apr 27, 2023
9eac600
Revert "add sleep"
thejaminator Apr 27, 2023
f5b8a53
Revert "try waiting"
thejaminator Apr 27, 2023
e644ec9
Revert "log more"
thejaminator Apr 27, 2023
322c9d8
Revert "fix typo"
thejaminator Apr 27, 2023
a2bae35
Revert "try out not sending the sentinel"
thejaminator Apr 27, 2023
34448dc
set num workeres to 8
thejaminator Apr 27, 2023
835562e
add commit
thejaminator Apr 27, 2023
6cf1c76
fsdp_single rename
thejaminator Apr 27, 2023
f81e6bf
more logs
thejaminator Apr 27, 2023
d50cda7
add range
thejaminator Apr 27, 2023
00de896
disable cpu offload
thejaminator Apr 27, 2023
155e91f
set min memory by dividng
thejaminator Apr 27, 2023
9b42104
add more logs for tests
thejaminator Apr 27, 2023
5accebb
rename tests
thejaminator Apr 27, 2023
4fe807d
use a sentinel class
thejaminator Apr 27, 2023
cfa854b
add log for FSDP
thejaminator Apr 27, 2023
b1af49e
save changes
thejaminator Apr 27, 2023
3804a84
add tol and better test
thejaminator Apr 27, 2023
4d20887
fix fsdp?? with imap??
thejaminator Apr 27, 2023
2cd4401
add assert for outputs
thejaminator Apr 27, 2023
72fc632
add comment
thejaminator Apr 27, 2023
1fca929
check it again
thejaminator Apr 27, 2023
6c9920d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2023
78c5847
fix imap
thejaminator Apr 27, 2023
192385b
fix assertion
thejaminator Apr 27, 2023
edd67de
try second
thejaminator Apr 27, 2023
fceb165
try second
thejaminator Apr 27, 2023
2772300
fix intiialization
thejaminator Apr 27, 2023
0fd5411
remove len hack
thejaminator Apr 27, 2023
save changes
thejaminator committed Apr 26, 2023
commit b384b8a96662ccfa080d20c026e0a0ebac3c77b7
204 changes: 181 additions & 23 deletions elk/extraction/extraction.py
@@ -21,6 +21,7 @@
from torch import Tensor
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.utils import ModelOutput

from .dataset_name import (
DatasetDictWithName,
@@ -39,6 +40,7 @@
select_usable_devices,
)
from ..utils.data_utils import flatten_list
from ..utils.fsdp import InferenceServer


@dataclass
@@ -90,6 +92,10 @@ def explode(self) -> list["Extract"]:
return copies


def identity(x):
return x


def extract_hiddens_list(
cfg: "Extract",
*,
@@ -112,6 +118,170 @@ def extract_hiddens_list(
)


@torch.inference_mode()
def extract_hiddens_fsdp(
cfg: "Extract",
*,
server: InferenceServer,
split_type: Literal["train", "val"],
) -> Iterable[dict]:
"""Run inference on a model with a set of prompts, yielding the hidden states."""
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Silence datasets logging messages from all but the first process

p_cfg = cfg.prompts
ds_names = p_cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

tokenizer = instantiate_tokenizer(cfg.model, truncation_side="left", verbose=True)
model = server._model

is_enc_dec = model.config.is_encoder_decoder
if is_enc_dec and cfg.use_encoder_states:
assert hasattr(model, "get_encoder") and callable(model.get_encoder)
model = assert_type(PreTrainedModel, model.get_encoder())
is_enc_dec = False

has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states)
if has_lm_preds:
print("Model has language model head, will store predictions.")

prompt_ds = load_prompts(
ds_names[0],
label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
num_classes=p_cfg.num_classes,
split_type=split_type,
stream=p_cfg.stream,
)
world_size = 1

# Add one to the number of layers to account for the embedding layer
layer_indices = cfg.layers

global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1]
# break `max_examples` among the processes roughly equally
max_examples = global_max_examples
device = "cpu" # doesn't matter, we're using FSDP

for example in islice(prompt_ds, max_examples):
num_variants = len(example["prompts"])
num_choices = len(example["prompts"][0])

hidden_dict = {
f"hidden_{layer_idx}": torch.empty(
num_variants,
num_choices,
server._model.config.hidden_size,
device=device,
dtype=torch.int16,
)
for layer_idx in layer_indices
}
lm_logits = torch.empty(
num_variants,
num_choices,
device=device,
dtype=torch.float32,
)
text_questions = []

# Iterate over variants
for i, record in enumerate(example["prompts"]):
variant_questions = []

# Iterate over answers
for j, choice in enumerate(record):
text = choice["question"]

# Only feed question, not the answer, to the encoder for enc-dec models
target = choice["answer"] if is_enc_dec else None

# Record the EXACT question we fed to the model
variant_questions.append(text)
encoding = tokenizer(
text,
# Keep [CLS] and [SEP] for BERT-style models
add_special_tokens=True,
return_tensors="pt",
text_target=target, # type: ignore[arg-type]
truncation=True,
).to(device)
input_ids = assert_type(Tensor, encoding.input_ids)

if is_enc_dec:
answer = assert_type(Tensor, encoding.labels)
else:
encoding2 = tokenizer(
choice["answer"],
# Don't include [CLS] and [SEP] in the answer
add_special_tokens=False,
return_tensors="pt",
).to(device)
answer = assert_type(Tensor, encoding2.input_ids)

input_ids = torch.cat([input_ids, answer], dim=-1)
if max_len := tokenizer.model_max_length:
cur_len = input_ids.shape[-1]
input_ids = input_ids[..., -min(cur_len, max_len) :]

# Make sure we only pass the arguments that the model expects
inputs = dict(input_ids=input_ids)
if is_enc_dec:
inputs["labels"] = answer

# make the dict a dataset
input_dataset = Dataset.from_dict(inputs)
# this is dumb but we'll just do it one by one for now
outputs = server.map(dataset=input_dataset, closure=identity)[0]

# Compute the log probability of the answer tokens if available
if has_lm_preds:
answer_len = answer.shape[-1]

log_p = outputs.logits[..., -answer_len:, :].log_softmax(dim=-1)
tokens = answer[..., None]
lm_logits[i, j] = log_p.gather(-1, tokens).sum()

elif isinstance(outputs, Seq2SeqLMOutput):
# The cross entropy loss is averaged over tokens, so we need to
# multiply by the length to get the total log probability.
length = encoding.labels.shape[-1]
lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length

hiddens = (
outputs.get("decoder_hidden_states") or outputs["hidden_states"]
)
# Throw out layers we don't care about
hiddens = [hiddens[i] for i in layer_indices]

# Current shape of each element: (batch_size, seq_len, hidden_size)
if cfg.token_loc == "first":
hiddens = [h[..., 0, :] for h in hiddens]
elif cfg.token_loc == "last":
hiddens = [h[..., -1, :] for h in hiddens]
elif cfg.token_loc == "mean":
hiddens = [h.mean(dim=-2) for h in hiddens]
else:
raise ValueError(f"Invalid token_loc: {cfg.token_loc}")

for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)

text_questions.append(variant_questions)

out_record: dict[str, Any] = dict(
label=example["label"],
variant_ids=example["template_names"],
text_questions=text_questions,
**hidden_dict,
)
if has_lm_preds:
out_record["model_logits"] = lm_logits

yield out_record


@torch.inference_mode()
def extract_hiddens(
cfg: "Extract",
@@ -343,29 +513,17 @@ def extract_hiddens_with_gpus(

# ctx = mp.get_context("spawn")
first_device = devices[0]
model = instantiate_model(
cfg.model, torch_dtype="auto" if first_device != "cpu" else torch.float32
)
with ThreadPool(len(devices)) as pool:
for split_name in split_names:
thunks: list[Callable[[], list[dict]]] = []
for rank, device in enumerate(devices):
# Create the functions to extract the hidden states
thunk: Callable[[], list[dict]] = partial(
extract_hiddens_list,
model=model,
cfg=cfg,
device=device,
rank=rank,
split_type=split_name, # type: ignore
world_size=len(devices),
)
thunks.append(thunk)
# Now evaluate them in parallel
split_result: list[dict] = flatten_list(
evaluate_with_processes(sequence=thunks, pool=pool)
)
results[split_name] = split_result

server = InferenceServer(model_str=cfg.model)

for split_name in split_names:
hiddens = []
for hidden in extract_hiddens_fsdp(
cfg=cfg, server=server, split_type=split_name
):
hiddens.append(hidden)
split_result: list[dict] = hiddens
results[split_name] = split_result

# Turn the results into a DatasetDict, that has the splits
# and the list of dicts from extract_hiddens_list
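Editor's note on the `has_lm_preds` branch in `extract_hiddens_fsdp` above: it scores each answer by summing the log-softmax of the model's logits over the answer token positions. A minimal, self-contained sketch of that computation with made-up tensor shapes (everything below is illustrative and not part of the diff):

import torch

# Hypothetical shapes: 1 sequence of length 10 over a 32000-token vocab,
# where the final 3 positions correspond to the answer tokens.
logits = torch.randn(1, 10, 32000)
answer = torch.randint(0, 32000, (1, 3))

answer_len = answer.shape[-1]
# Vocabulary log-probabilities at the answer positions only.
log_p = logits[..., -answer_len:, :].log_softmax(dim=-1)
# Gather the log-probability of each actual answer token and sum them,
# giving the total log-probability of the answer string (the quantity
# stored in lm_logits[i, j] in the code above).
total_log_prob = log_p.gather(-1, answer[..., None]).sum()
print(total_log_prob.item())
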
38 changes: 22 additions & 16 deletions elk/utils/fsdp.py
@@ -18,36 +18,44 @@
from transformers.modeling_outputs import ModelOutput
from typing import Any, Callable, Iterable, Type, cast

from ..multiprocessing import A
from ..utils import instantiate_model, pytree_map, select_usable_devices


@dataclass
class InferenceServer:
"""High-level interface for running inference on a model on multiple GPUs.

This is basically a glorified `multiprocessing.Pool`. The only difference is that
each worker maintains a copy of the model on a dedicated GPU.
"""

model_str: str
num_workers: int = -1

cpu_offload: bool = False
fsdp: bool = False

def __post_init__(self):
def __init__(
self,
model_str: str,
num_workers: int = -1,
cpu_offload: bool = False,
fsdp: bool = False,
):
self.model_str = model_str
self.num_workers = num_workers
self.cpu_offload = cpu_offload
self.fsdp = fsdp
self._current_id = 0
self._process_ctx: mp.ProcessContext | None = None

self._result_queues = []
self._task_queues = []
model = instantiate_model(model_str, torch_dtype="auto")
model.share_memory()
self._model = model
self._start()

@property
def running(self) -> bool:
"""Whether the server is running."""
return self._process_ctx is not None

def start(self) -> None:
def _start(self) -> None:
"""Spin up the workers."""
if self._process_ctx is not None:
raise RuntimeError("The server is already running")
@@ -56,8 +64,7 @@ def start(self) -> None:
# This ensures that we don't copy the model num_workers times on the CPU and
# run out of RAM for large models
print("Loading model...")
model = instantiate_model(self.model_str, torch_dtype="auto")
model.share_memory()
model = self._model
model_size = sum(p.numel() * p.element_size() for p in model.parameters())

# Determine which GPUs we can use
@@ -115,25 +122,24 @@ def shutdown(self) -> bool:

# Support use as a context manager, just like mp.Pool
def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.shutdown()

def map(
self,
closure: Callable[[ModelOutput], Any],
closure: Callable[[ModelOutput], A],
dataset: Dataset,
) -> list:
) -> list[A]:
"""Run inference on the given inputs, running a closure on the outputs."""
return list(self.imap(closure, dataset))

def imap(
self,
closure: Callable[[ModelOutput], None],
closure: Callable[[ModelOutput], A],
dataset: Dataset,
) -> Iterable:
) -> Iterable[A]:
"""Run inference on the given inputs, running a closure on the outputs."""
if self._process_ctx is None:
raise RuntimeError("Can't run inference on a server that isn't running")
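Editor's note: for reference, a rough usage sketch of the refactored `InferenceServer`, based only on the interface visible in this diff (the constructor now loads the model, shares its memory, and starts the workers itself, while `map`/`imap` apply a closure to each `ModelOutput` in the worker processes). The model name, toy dataset, and closure below are placeholders, not part of the PR:

from datasets import Dataset
from elk.utils.fsdp import InferenceServer

def get_logits(output):
    # Closure run on each ModelOutput returned by a worker.
    return output.logits

# Constructing the server loads the model once and spins up the workers
# (num_workers=-1 means one worker per usable GPU).
with InferenceServer("gpt2", num_workers=-1, fsdp=False) as server:
    # Toy pre-tokenized dataset; the extraction code above builds this
    # from the tokenizer's input_ids one example at a time.
    ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6]]})
    outputs = server.map(closure=get_logits, dataset=ds)
# Exiting the `with` block calls shutdown(), terminating the workers.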