Support non-fast tokenizers like LLaMA (EleutherAI#202)
* Refactor metrics into evaluate_preds

* Fix stupid CCS bug

* Cluster bootstrap for AUROC; boost default sample size

* Cluster bootstrap for accuracy

* Allow for arbitrary hparam selection in sweep

* Don't normalize LM probs twice

* Fix normalization of LM logits

* Support slow tokenizers and enc-dec (properly)

* Bring back use_encoder_states

* More robustness to models with fp32 weights

* Rename 'text_inputs' to 'text_questions'
norabelrose committed Apr 21, 2023
1 parent 7d7c175 commit 4748a2b
Showing 14 changed files with 146 additions and 147 deletions.
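Two of the commit-message bullets above mention a cluster bootstrap for AUROC and accuracy. The metrics code itself lives in files whose diffs are not loaded on this page, so the snippet below is only a generic illustration of the idea, with made-up names: resample whole examples (clusters) with replacement rather than individual prompt variants, so correlated predictions from the same example stay together in every resample.

```python
# Hypothetical illustration of a cluster (per-example) bootstrap for AUROC.
# Function and variable names are illustrative, not the elk implementation.
import torch
from sklearn.metrics import roc_auc_score  # assumes scikit-learn is available

def cluster_bootstrap_auroc(
    labels: torch.Tensor,   # shape (n_examples, n_variants), values in {0, 1}
    scores: torch.Tensor,   # shape (n_examples, n_variants), real-valued
    n_boot: int = 1000,
    seed: int = 0,
) -> list[float]:
    """Bootstrap AUROC by resampling examples, keeping all variants of an example together."""
    rng = torch.Generator().manual_seed(seed)
    n = labels.shape[0]
    estimates = []
    for _ in range(n_boot):
        idx = torch.randint(0, n, (n,), generator=rng)  # sample examples with replacement
        y = labels[idx].flatten().numpy()
        s = scores[idx].flatten().numpy()
        if y.min() == y.max():
            continue  # AUROC is undefined if only one class was drawn
        estimates.append(roc_auc_score(y, s))
    return estimates
```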
6 changes: 3 additions & 3 deletions elk/debug_logging.py
@@ -27,17 +27,17 @@ def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None:
)

train_split, val_split = select_train_val_splits(ds)
text_inputs = ds[val_split][0]["text_inputs"]
text_questions = ds[val_split][0]["text_questions"]
template_ids = ds[val_split][0]["variant_ids"]
label = ds[val_split][0]["label"]

# log the train size and val size
logging.info(f"Train size: {len(ds[train_split])}")
logging.info(f"Val size: {len(ds[val_split])}")

templates_text = f"{len(text_inputs)} templates used:\n"
templates_text = f"{len(text_questions)} templates used:\n"
trailing_whitespace = False
for (text0, text1), id in zip(text_inputs, template_ids):
for (text0, text1), id in zip(text_questions, template_ids):
templates_text += (
f'***---TEMPLATE "{id}"---***\n'
f"{'false' if label else 'true'}:\n"
2 changes: 2 additions & 0 deletions elk/evaluation/evaluate.py
@@ -43,6 +43,8 @@ class Eval(Serializable):
out_dir: Path | None = None
skip_supervised: bool = False

disable_cache: bool = field(default=False, to_dict=False)

def execute(self):
transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"

124 changes: 72 additions & 52 deletions elk/extraction/extraction.py
@@ -5,12 +5,14 @@
from dataclasses import InitVar, dataclass
from itertools import islice
from typing import Any, Iterable, Literal
from warnings import filterwarnings

import torch
from datasets import (
Array2D,
Array3D,
DatasetDict,
DownloadMode,
Features,
Sequence,
SplitDict,
@@ -20,17 +22,17 @@
)
from simple_parsing import Serializable, field
from torch import Tensor
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
convert_span,
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
select_train_val_splits,
select_usable_devices,
@@ -58,6 +60,7 @@ class Extract(Serializable):
layers: tuple[int, ...] = ()
layer_stride: InitVar[int] = 1
token_loc: Literal["first", "last", "mean"] = "last"
use_encoder_states: bool = False

def __post_init__(self, layer_stride: int):
if self.layers and layer_stride > 1:
@@ -85,7 +88,7 @@ def explode(self) -> list["Extract"]:
return copies


@torch.no_grad()
@torch.inference_mode()
def extract_hiddens(
cfg: "Extract",
*,
@@ -99,12 +102,30 @@ def extract_hiddens(

# Silence datasets logging messages from all but the first process
if rank != 0:
filterwarnings("ignore")
logging.disable(logging.CRITICAL)

p_cfg = cfg.prompts
ds_names = p_cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

model = instantiate_model(
cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
).to(device)
tokenizer = instantiate_tokenizer(
cfg.model, truncation_side="left", verbose=rank == 0
)

is_enc_dec = model.config.is_encoder_decoder
if is_enc_dec and cfg.use_encoder_states:
assert hasattr(model, "get_encoder") and callable(model.get_encoder)
model = assert_type(PreTrainedModel, model.get_encoder())
is_enc_dec = False

has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states)
if has_lm_preds and rank == 0:
print("Model has language model head, will store predictions.")

prompt_ds = load_prompts(
ds_names[0],
label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
@@ -113,17 +134,7 @@
stream=p_cfg.stream,
rank=rank,
world_size=world_size,
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
cfg.model, truncation_side="left", verbose=False
)
has_lm_preds = is_autoregressive(model.config)
if has_lm_preds and rank == 0:
print("Model has language model head, will store predictions.")

# Iterating over questions
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
@@ -155,62 +166,64 @@
device=device,
dtype=torch.float32,
)
text_inputs = []
text_questions = []

# Iterate over variants
for i, record in enumerate(example["prompts"]):
variant_inputs = []
variant_questions = []

# Iterate over answers
for j, choice in enumerate(record):
text = choice["text"]

# TODO: Do something smarter than "rindex" here. Really we want to
# get the span of the answer directly from Jinja, but that doesn't
# seem possible. This approach may fail for complex templates.
answer_start = text.rindex(choice["answer"])
text = choice["question"]

# Only feed question, not the answer, to the encoder for enc-dec models
if model.config.is_encoder_decoder:
# TODO: Maybe make this more generic for complex templates?
text = text[:answer_start].rstrip()
target = choice["answer"]
else:
target = None
target = choice["answer"] if is_enc_dec else None

# Record the EXACT string we fed to the model
variant_inputs.append(text)
inputs = tokenizer(
# Record the EXACT question we fed to the model
variant_questions.append(text)
encoding = tokenizer(
text,
return_offsets_mapping=True,
add_special_tokens=False,
return_tensors="pt",
text_target=target, # type: ignore[arg-type]
truncation=True,
)
).to(device)
input_ids = assert_type(Tensor, encoding.input_ids)

if is_enc_dec:
answer = assert_type(Tensor, encoding.labels)
else:
encoding2 = tokenizer(
choice["answer"],
add_special_tokens=False,
return_tensors="pt",
).to(device)
answer = assert_type(Tensor, encoding2.input_ids)

input_ids = torch.cat([input_ids, answer], dim=-1)
if max_len := tokenizer.model_max_length:
input_ids = input_ids[..., -max_len:]

# The offset_mapping is a sorted list of (start, end) tuples. We locate
# the start of the answer in the tokenized sequence with binary search.
offsets = inputs.pop("offset_mapping").squeeze().tolist()
inputs = inputs.to(device)
# Make sure we only pass the arguments that the model expects
inputs = dict(input_ids=input_ids)
if is_enc_dec:
inputs["labels"] = answer

# Run the forward pass
outputs = model(**inputs, output_hidden_states=True)
with torch.autocast("cuda", enabled=torch.cuda.is_available()):
outputs = model(**inputs, output_hidden_states=True)

# Compute the log probability of the answer tokens if available
if has_lm_preds:
start, end = convert_span(
offsets, (answer_start, answer_start + len(choice["answer"]))
)
log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(
dim=-1
)
tokens = inputs.input_ids[..., start:end, None]
answer_len = answer.shape[-1]

log_p = outputs.logits[..., -answer_len:, :].log_softmax(dim=-1)
tokens = answer[..., None]
lm_logits[i, j] = log_p.gather(-1, tokens).sum()

elif isinstance(outputs, Seq2SeqLMOutput):
# The cross entropy loss is averaged over tokens, so we need to
# multiply by the length to get the total log probability.
length = inputs.labels.shape[-1]
length = encoding.labels.shape[-1]
lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length

hiddens = (
Expand All @@ -235,12 +248,12 @@ def extract_hiddens(
for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)

text_inputs.append(variant_inputs)
text_questions.append(variant_questions)

out_record: dict[str, Any] = dict(
label=example["label"],
variant_ids=example["template_names"],
text_inputs=text_inputs,
text_questions=text_questions,
**hidden_dict,
)
if has_lm_preds:
@@ -255,7 +268,11 @@ def _extraction_worker(**kwargs):


def extract(
cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
cfg: "Extract",
*,
disable_cache: bool = False,
num_gpus: int = -1,
min_gpu_mem: int | None = None,
) -> DatasetDict:
"""Extract hidden states from a model and return a `DatasetDict` containing them."""

@@ -311,7 +328,7 @@ def get_splits() -> SplitDict:
length=num_variants,
),
"label": Value(dtype="int64"),
"text_inputs": Sequence(
"text_questions": Sequence(
Sequence(
Value(dtype="string"),
),
@@ -320,7 +337,7 @@
}

# Only add model_logits if the model is an autoregressive model
if is_autoregressive(model_cfg):
if is_autoregressive(model_cfg, not cfg.use_encoder_states):
other_cols["model_logits"] = Array2D(
shape=(num_variants, num_classes),
dtype="float32",
@@ -352,7 +369,10 @@ def get_splits() -> SplitDict:

ds = dict()
for split, builder in builders.items():
builder.download_and_prepare(num_proc=len(devices))
builder.download_and_prepare(
download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
num_proc=len(devices),
)
ds[split] = builder.as_dataset(split=split)

return DatasetDict(ds)
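The core of the tokenizer change above is that only fast (Rust-backed) tokenizers implement `return_offsets_mapping`, so with a slow tokenizer like LLaMA's the answer span can no longer be located via character offsets. Instead, the answer is tokenized separately, concatenated to the question, and scored from the final positions of the logits. Below is a minimal, self-contained sketch of that scoring pattern; the model, prompt, and exact indexing are illustrative and not lifted from the diff.

```python
# Sketch: scoring an answer continuation under a causal LM without offset mappings.
# The model name and prompt are placeholders; this is not code from the commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

question = "Q: Is the sky blue?\nA:"
answer = " yes"

q_ids = tokenizer(question, return_tensors="pt").input_ids
a_ids = tokenizer(answer, add_special_tokens=False, return_tensors="pt").input_ids
input_ids = torch.cat([q_ids, a_ids], dim=-1)

with torch.inference_mode():
    logits = model(input_ids).logits

# The logit at position t predicts the token at position t + 1, so the k answer
# tokens at the end of the sequence are scored by the k logits just before them.
k = a_ids.shape[-1]
log_probs = logits[:, -k - 1 : -1, :].log_softmax(dim=-1)
answer_log_prob = log_probs.gather(-1, a_ids[..., None]).sum()
print(f"log P(answer | question) = {answer_log_prob.item():.3f}")
```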
12 changes: 8 additions & 4 deletions elk/extraction/prompt_loading.py
@@ -104,6 +104,7 @@ def explode(self) -> list["PromptConfig"]:

def load_prompts(
ds_string: str,
*,
label_column: Optional[str] = None,
num_classes: int = 0,
num_shots: int = 0,
@@ -119,6 +120,10 @@
Args:
ds_string: Space-delimited name of the HuggingFace dataset to use,
e.g. `"super_glue boolq"` or `"imdb"`.
label_column: The column containing the labels. By default, we infer this from
the datatypes of the columns in the dataset.
num_classes: The number of classes in the dataset. If zero, we infer this from
the datatypes of the columns in the dataset.
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot.
seed: The seed to use for prompt randomization.
@@ -230,23 +235,22 @@ def qa_cat(q: str, a: str) -> str:
fake_example[label_column] = answer_idx

q, a = template.apply(fake_example)
text = qa_cat(q, a or string_choices[answer_idx])
prompt_counter[text] += 1
prompt_counter[(q, a)] += 1

if fewshot_iter is not None:
# Infinite iterator so we don't need to worry about StopIteration
fewshot_examples = next(fewshot_iter)
fewshot_texts = [
qa_cat(q, a) for q, a in map(template.apply, fewshot_examples)
]
text = "\n\n".join(fewshot_texts) + "\n\n" + text
q = "\n\n".join(fewshot_texts) + "\n\n" + q

choices.append(
dict(
# Strip whitespace from the answer to make it easier to
# compare with the model's output
answer=a.strip(),
text=text,
question=q,
)
)

5 changes: 0 additions & 5 deletions elk/metrics/calibration.py
@@ -1,4 +1,3 @@
import warnings
from dataclasses import dataclass, field

import torch
@@ -68,10 +67,6 @@ def compute(self, p: int = 2) -> CalibrationEstimate:
break

elif not torch.all(freqs * (1 - freqs)):
warnings.warn(
"Calibration error estimate may be unreliable due to insufficient"
" data in some bins."
)
break

# Save the current binning, it's monotonic and may be the best one
7 changes: 6 additions & 1 deletion elk/run.py
@@ -38,7 +38,12 @@ class Run(ABC):

def __post_init__(self):
self.datasets = [
extract(cfg, num_gpus=self.cfg.num_gpus, min_gpu_mem=self.cfg.min_gpu_mem)
extract(
cfg,
disable_cache=self.cfg.disable_cache,
num_gpus=self.cfg.num_gpus,
min_gpu_mem=self.cfg.min_gpu_mem,
)
for cfg in self.cfg.data.explode()
]

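The `run.py` change above just forwards the new `disable_cache` flag from the config into `extract`, where (per the `extraction.py` hunk) it becomes `DownloadMode.FORCE_REDOWNLOAD` for the dataset builder. For reference, this is the same download-mode mechanism exposed by the ordinary `datasets` loading API; the dataset name below is only an example.

```python
# Sketch: forcing the datasets library to ignore its cache, as disable_cache now
# does for the extracted hidden states. The dataset name is just an example.
from datasets import DownloadMode, load_dataset

ds = load_dataset(
    "imdb",
    download_mode=DownloadMode.FORCE_REDOWNLOAD,  # rebuild instead of reusing the cache
)
```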
14 changes: 13 additions & 1 deletion elk/training/eigen_reporter.py
@@ -215,7 +215,19 @@ def fit_streaming(self, truncated: bool = False) -> float:
if truncated:
L, Q = truncated_eigh(A, k=self.config.num_heads)
else:
L, Q = torch.linalg.eigh(A)
try:
L, Q = torch.linalg.eigh(A)
except torch.linalg.LinAlgError as e:
# Check if the matrix has non-finite values
if not A.isfinite().all():
raise ValueError(
"Fitting the reporter failed because the VINC matrix has "
"non-finite entries. Usually this means the hidden states "
"themselves had non-finite values."
) from e
else:
raise e

L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :]

self.weight.data = Q.T
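The `eigen_reporter.py` change above doesn't alter the math; it only converts an opaque `eigh` failure into an error message pointing at non-finite entries in the VINC matrix, which usually trace back to non-finite hidden states. A quick upstream check along those lines (purely illustrative, not part of the commit):

```python
# Sketch: checking extracted hidden states for NaN/Inf before fitting a reporter.
# Shapes and names are illustrative.
import torch

hiddens = torch.randn(1000, 4096)      # (n examples, hidden size)
hiddens[3, 7] = float("inf")           # simulate a bad activation

if not hiddens.isfinite().all():
    bad = (~hiddens.isfinite()).nonzero()
    print(f"{len(bad)} non-finite entries; first at index {bad[0].tolist()}")
```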
(Diffs for the remaining 7 changed files were not loaded.)
