"""Functions for extracting the hidden states of a model."""
import logging
import os
from copy import copy
from dataclasses import InitVar, dataclass
from itertools import islice
from typing import Any, Iterable, Literal
import torch
from datasets import (
Array2D,
Array3D,
DatasetDict,
Features,
Sequence,
SplitDict,
SplitInfo,
Value,
get_dataset_config_info,
)
from simple_parsing import Serializable, field
from torch import Tensor
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput
from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
convert_span,
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
is_autoregressive,
select_train_val_splits,
select_usable_devices,
)
from .generator import _GeneratorBuilder
from .prompt_loading import PromptConfig, load_prompts


@dataclass
class Extract(Serializable):
    """
    Args:
        model: HuggingFace model string identifying the language model to extract
            hidden states from.
        prompts: The configuration for the prompts.
        layers: The layers to extract hidden states from.
        layer_stride: Shortcut for setting `layers` to `range(0, num_layers, stride)`.
        token_loc: The location of the token to extract hidden states from. Can be
            either "first", "last", or "mean". Defaults to "last".
    """

    prompts: PromptConfig
    model: str = field(positional=True)
    layers: tuple[int, ...] = ()
    layer_stride: InitVar[int] = 1
    token_loc: Literal["first", "last", "mean"] = "last"
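
    # Illustrative construction (the dataset name is arbitrary, and `datasets`
    # is assumed to be a list field on PromptConfig, as its usage below suggests):
    #   cfg = Extract(model="gpt2", prompts=PromptConfig(datasets=["imdb"]), layer_stride=2)
    # which would extract hiddens from every other layer of GPT-2.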

    def __post_init__(self, layer_stride: int):
        if self.layers and layer_stride > 1:
            raise ValueError(
                "Cannot use both --layers and --layer-stride. Please use only one."
            )
        elif layer_stride > 1:
            from transformers import AutoConfig, PretrainedConfig

            # Look up the model config to get the number of layers
            config = assert_type(
                PretrainedConfig, AutoConfig.from_pretrained(self.model)
            )
            self.layers = tuple(range(0, config.num_hidden_layers, layer_stride))

    def explode(self) -> list["Extract"]:
        """Explode this config into a list of configs, one for each dataset in
        `prompts`."""
        copies = []

        for prompt_cfg in self.prompts.explode():
            cfg = copy(self)
            cfg.prompts = prompt_cfg
            copies.append(cfg)

        return copies


@torch.no_grad()
def extract_hiddens(
    cfg: "Extract",
    *,
    device: str | torch.device = "cpu",
    split_type: Literal["train", "val"] = "train",
    rank: int = 0,
    world_size: int = 1,
) -> Iterable[dict]:
    """Run inference on a model with a set of prompts, yielding the hidden states."""
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Silence datasets logging messages from all but the first process
    if rank != 0:
        logging.disable(logging.CRITICAL)

    p_cfg = cfg.prompts
    ds_names = p_cfg.datasets
    assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

    prompt_ds = load_prompts(
        ds_names[0],
        label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
        num_classes=p_cfg.num_classes,
        split_type=split_type,
        stream=p_cfg.stream,
        rank=rank,
        world_size=world_size,
        combined_template_output_path=cfg.prompts.combined_template_output_path,
    )  # this dataset is already sharded, but hasn't been truncated to max_examples

    model = instantiate_model(
        cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model, truncation_side="left", verbose=False
    )
    has_lm_preds = is_autoregressive(model.config)
    if has_lm_preds and rank == 0:
        print("Model has language model head, will store predictions.")

    # Iterating over questions
    layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))

    global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1]
    # break `max_examples` among the processes roughly equally
    max_examples = global_max_examples // world_size
    # the last process gets the remainder (which is usually small)
    if rank == world_size - 1:
        max_examples += global_max_examples % world_size

    for example in islice(prompt_ds, max_examples):
        num_variants = len(example["prompts"])
        num_choices = len(example["prompts"][0])
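
        # Pre-allocate per-example buffers. Hidden states are stored as int16
        # (encoded with float32_to_int16 below) to halve memory and disk usage.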
        hidden_dict = {
            f"hidden_{layer_idx}": torch.empty(
                num_variants,
                num_choices,
                model.config.hidden_size,
                device=device,
                dtype=torch.int16,
            )
            for layer_idx in layer_indices
        }
        lm_logits = torch.empty(
            num_variants,
            num_choices,
            device=device,
            dtype=torch.float32,
        )
        text_inputs = []

        # Iterate over variants
        for i, record in enumerate(example["prompts"]):
            variant_inputs = []

            # Iterate over answers
            for j, choice in enumerate(record):
                text = choice["text"]

                # TODO: Do something smarter than "rindex" here. Really we want to
                # get the span of the answer directly from Jinja, but that doesn't
                # seem possible. This approach may fail for complex templates.
                answer_start = text.rindex(choice["answer"])

                # Only feed question, not the answer, to the encoder for enc-dec models
                if model.config.is_encoder_decoder:
                    # TODO: Maybe make this more generic for complex templates?
                    text = text[:answer_start].rstrip()
                    target = choice["answer"]
                else:
                    target = None

                # Record the EXACT string we fed to the model
                variant_inputs.append(text)
                inputs = tokenizer(
                    text,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                    text_target=target,  # type: ignore[arg-type]
                    truncation=True,
                )

                # The offset_mapping is a sorted list of (start, end) tuples. We locate
                # the start of the answer in the tokenized sequence with binary search.
                offsets = inputs.pop("offset_mapping").squeeze().tolist()
                inputs = inputs.to(device)

                # Run the forward pass
                outputs = model(**inputs, output_hidden_states=True)

                # Compute the log probability of the answer tokens if available
                if has_lm_preds:
                    start, end = convert_span(
                        offsets, (answer_start, answer_start + len(choice["answer"]))
                    )
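                    # Logits at position t predict token t + 1, so the window is
                    # shifted back by one to score the answer tokens themselves.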
                    log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(
                        dim=-1
                    )
                    tokens = inputs.input_ids[..., start:end, None]
                    lm_logits[i, j] = log_p.gather(-1, tokens).sum()

                elif isinstance(outputs, Seq2SeqLMOutput):
                    # The cross entropy loss is averaged over tokens, so we need to
                    # multiply by the length to get the total log probability.
                    length = inputs.labels.shape[-1]
                    lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length

                hiddens = (
                    outputs.get("decoder_hidden_states") or outputs["hidden_states"]
                )
                # First element of list is the input embeddings
                hiddens = hiddens[1:]

                # Throw out layers we don't care about
                hiddens = [hiddens[i] for i in layer_indices]

                # Current shape of each element: (batch_size, seq_len, hidden_size)
                if cfg.token_loc == "first":
                    hiddens = [h[..., 0, :] for h in hiddens]
                elif cfg.token_loc == "last":
                    hiddens = [h[..., -1, :] for h in hiddens]
                elif cfg.token_loc == "mean":
                    hiddens = [h.mean(dim=-2) for h in hiddens]
                else:
                    raise ValueError(f"Invalid token_loc: {cfg.token_loc}")

                for layer_idx, hidden in zip(layer_indices, hiddens):
                    hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)

            text_inputs.append(variant_inputs)

        out_record: dict[str, Any] = dict(
            label=example["label"],
            variant_ids=example["template_names"],
            text_inputs=text_inputs,
            **hidden_dict,
        )
        if has_lm_preds:
            out_record["model_logits"] = lm_logits

        yield out_record


# Dataset.from_generator wraps all the arguments in lists, so we unpack them here
def _extraction_worker(**kwargs):
    yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})


def extract(
    cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
) -> DatasetDict:
    """Extract hidden states from a model and return a `DatasetDict` containing them."""

    def get_splits() -> SplitDict:
        available_splits = assert_type(SplitDict, info.splits)
        train_name, val_name = select_train_val_splits(available_splits)
        print(
            # Cyan color for dataset name
            f"\033[36m{info.builder_name}\033[0m: using '{train_name}' for training and"
            f" '{val_name}' for validation"
        )
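
        # `max_examples` holds one cap per split, in (train, val) order,
        # matching how extract_hiddens indexes it above.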
        limit_list = cfg.prompts.max_examples

        return SplitDict(
            {
                k: SplitInfo(
                    name=k,
                    num_examples=min(limit, v.num_examples) * len(cfg.prompts.datasets),
                    dataset_name=v.dataset_name,
                )
                for limit, (k, v) in zip(limit_list, available_splits.items())
            },
            dataset_name=available_splits.dataset_name,
        )

    model_cfg = AutoConfig.from_pretrained(cfg.model)

    # Retrieve info, used to get splits
    ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
    info = get_dataset_config_info(ds_name, config_name or None)

    ds_features = assert_type(Features, info.features)
    label_col = (
        cfg.prompts.label_columns[0]
        if cfg.prompts.label_columns
        else infer_label_column(ds_features)
    )
    num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col])
    num_variants = cfg.prompts.num_variants
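    # A negative num_variants means "use every template available for this dataset"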
    if num_variants < 0:
        prompter = DatasetTemplates(ds_name, config_name)
        num_variants = len(prompter.templates)

    layer_cols = {
        f"hidden_{layer}": Array3D(
            dtype="int16",
            shape=(num_variants, num_classes, model_cfg.hidden_size),
        )
        for layer in cfg.layers or range(model_cfg.num_hidden_layers)
    }
    other_cols = {
        "variant_ids": Sequence(
            Value(dtype="string"),
            length=num_variants,
        ),
        "label": Value(dtype="int64"),
        "text_inputs": Sequence(
            Sequence(
                Value(dtype="string"),
            ),
            length=num_variants,
        ),
    }

    # Only add model_logits if the model is an autoregressive model
    if is_autoregressive(model_cfg):
        other_cols["model_logits"] = Array2D(
            shape=(num_variants, num_classes),
            dtype="float32",
        )

    devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
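    # One extraction worker per device; each gen_kwargs value is a list with one
    # entry per worker, which _extraction_worker unpacks above.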
    builders = {
        split_name: _GeneratorBuilder(
            builder_name=info.builder_name,
            config_name=info.config_name,
            cache_dir=None,
            features=Features({**layer_cols, **other_cols}),
            generator=_extraction_worker,
            split_name=split_name,
            split_info=split_info,
            gen_kwargs=dict(
                cfg=[cfg] * len(devices),
                device=devices,
                rank=list(range(len(devices))),
                split_type=[split_name] * len(devices),
                world_size=[len(devices)] * len(devices),
            ),
        )
        for (split_name, split_info) in get_splits().items()
    }

    import multiprocess as mp
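
    # "spawn" is likely required here because forked workers cannot reuse the
    # parent process's CUDA context.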
    mp.set_start_method("spawn", force=True)  # type: ignore[attr-defined]

    ds = dict()
    for split, builder in builders.items():
        builder.download_and_prepare(num_proc=len(devices))
        ds[split] = builder.as_dataset(split=split)

    return DatasetDict(ds)