Merge pull request #293 from EleutherAI/answer-tokenization
Answer tokenization
AlexTMallen committed Nov 1, 2023
2 parents 70a3290 + 309428b commit 937e71d
Showing 6 changed files with 43 additions and 10 deletions.
4 changes: 4 additions & 0 deletions elk/debug_logging.py
@@ -11,13 +11,17 @@ def save_debug_log(datasets: list[DatasetDictWithName], out_dir: Path) -> None:
training issues.
"""

print(f"Saving debug log to {out_dir}/debug.log")
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s %(levelname)s:\n%(message)s",
filename=out_dir / "debug.log",
filemode="w",
)

if len(datasets) == 0:
logging.warning("No datasets found!")

for ds_name, ds in datasets:
logging.info(
"=========================================\n"
23 changes: 20 additions & 3 deletions elk/extraction/extraction.py
@@ -183,7 +183,7 @@ def tokenize_dataset(
for example in prompt_ds:
num_variants = len(example["template_names"])

# Check if we've yielded enough examples
# Check if we've appended enough examples
if len(out_records) >= max_examples * num_variants:
break

@@ -199,12 +199,21 @@ def tokenize_dataset(
assert len(answer_choices) == 2
answer_ids = []
for choice in answer_choices:
a_id = tokenizer.encode(choice, add_special_tokens=False)
a_id = tokenizer.encode(" " + choice, add_special_tokens=False)

# the Llama tokenizer splits off leading spaces
if tokenizer.decode(a_id[0]).strip() == "":
a_id_without_space = tokenizer.encode(
choice, add_special_tokens=False
)
assert a_id_without_space == a_id[1:]
a_id = a_id_without_space

if len(a_id) > 1:
print(
f"WARNING: answer choice '{choice}' is more than one "
"token, LM probabilities will be calculated using the "
"first token only."
f"first token only ({tokenizer.decode(a_id[0])})"
)
answer_ids.append(a_id[0])
else:
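Why the leading space matters: GPT-2-style BPE tokenizers fold a leading space into the following word, while SentencePiece-based tokenizers such as Llama's may split it off as a separate whitespace token. A minimal sketch of the behavior the hunk above guards against (model name is illustrative; exact ids vary by tokenizer):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode(" True", add_special_tokens=False)
# GPT-2 typically returns a single id here, whose decoding begins with a space.
print(ids, repr(tok.decode(ids[0])))

# A Llama-style tokenizer may instead return two ids, where the first decodes
# to pure whitespace; dropping it should leave the encoding of the bare choice
# string, which is exactly what the assertion in the hunk checks.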
@@ -249,7 +258,15 @@ def tokenize_dataset(
# print an example text to stdout
if len(out_records) == 0:
print(f"Example text: {record_variants[0]['text']}")
neg_id, pos_id = record_variants[0]["answer_ids"]
print(f'\tneg choice token: "{tokenizer.decode(neg_id)}"')
print(f'\tpos choice token: "{tokenizer.decode(pos_id)}"')
out_records.extend(record_variants)
else:
print(
f"WARNING: reached end of dataset {ds_names[0]} before collecting "
f"{max_examples} examples (only got {len(out_records)})."
)

# transpose the list of dicts into a dict of lists
out_records = {k: [d[k] for d in out_records] for k in out_records[0]}
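The dict comprehension on the last line is the usual list-of-dicts to dict-of-lists transpose; a tiny illustration (toy records, assuming every record shares the keys of the first):

records = [{"text": "a", "label": 0}, {"text": "b", "label": 1}]
columns = {k: [d[k] for d in records] for k in records[0]}
# -> {"text": ["a", "b"], "label": [0, 1]}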
1 change: 0 additions & 1 deletion elk/extraction/inference_server.py
@@ -352,7 +352,6 @@ def maybe_unsqueeze(v):
inputs_cuda = pytree_map(
lambda v: maybe_unsqueeze(v.to(device)), input_record
)
# TODO: have model kwargs so we don't have to duplicate kwargs at each row
outputs = model(**inputs_cuda, **model_kwargs)

if callable(closure):
8 changes: 7 additions & 1 deletion elk/extraction/prompt_loading.py
@@ -47,6 +47,8 @@ def load_prompts(
ds = assert_type(Dataset, ds_dict[split_name])
if "row_id" not in ds.column_names:
ds = ds.add_column("row_id", range(len(ds))) # type: ignore
else:
print("Found `row_id` column, using it as the example id")
ds = ds.shuffle(seed=seed)

prompter, using_blank = get_prompter(ds_name, config_name, template_path)
@@ -96,6 +98,7 @@ def load_prompts(
fewshot_iter = None

if label_column in ds.features and balance:
print(f"Balancing dataset by {label_column}")
ds = BalancedSampler(
ds.to_iterable_dataset(),
set(label_choices),
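BalancedSampler is defined elsewhere in elk; as a rough mental model only (a hypothetical minimal version, not the repo's implementation), label balancing can buffer examples per label and emit one complete round whenever every label has something queued:

from collections import deque
from typing import Iterable, Iterator

def balanced(stream: Iterable[dict], labels: set, key: str = "label") -> Iterator[dict]:
    buffers = {label: deque() for label in labels}
    for ex in stream:
        if ex[key] in buffers:
            buffers[ex[key]].append(ex)
        # emit one example per label whenever all labels are available,
        # so the yielded stream stays exactly balanced
        while all(buffers.values()):
            for label in labels:
                yield buffers[label].popleft()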
@@ -135,7 +138,10 @@ def _convert_to_prompts(

for template in templates:
statement = template.apply(example)
prompt_counter[statement] += 1

choices = template.get_fixed_answer_choices_list()
choices = tuple(choices) if choices is not None else None
prompt_counter[(statement, choices)] += 1
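The tuple conversion is what makes the new counter key hashable; a list of answer choices could not be used inside a Counter key. A quick illustration:

from collections import Counter

counter = Counter()
counter[("Is the sky blue?", ("no", "yes"))] += 1  # fine: tuples are hashable
counter[("Is the sky blue?", None)] += 1           # fine: templates without fixed choices
# counter[("Is the sky blue?", ["no", "yes"])] += 1  would raise
# TypeError: unhashable type: 'list'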

if fewshot_iter is not None:
# Infinite iterator so we don't need to worry about StopIteration
14 changes: 10 additions & 4 deletions elk/run.py
@@ -154,16 +154,22 @@ def prepare_data(
for ds_name, ds in self.datasets:
key = select_split(ds, split_type)

split = ds[key].with_format("torch", device=device, dtype=torch.int16)
labels = assert_type(Tensor, split["label"])
hidden_cols = [
col for col in ds[key].column_names if col.startswith("hidden_")
]
split = ds[key].with_format(
"torch", device=device, dtype=torch.int16, columns=hidden_cols
)
# hiddens shape: (num_examples, num_variants, hidden_d)
hiddens = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))
if self.prompt_indices:
hiddens = hiddens[:, self.prompt_indices]

# convert the remaining columns to torch
split = split.with_format("torch", device=device)
labels = assert_type(Tensor, split["label"])
if "lm_log_odds" in split.column_names:
with split.formatted_as("torch", device=device):
lm_preds = assert_type(Tensor, split["lm_log_odds"])
lm_preds = assert_type(Tensor, split["lm_log_odds"])
else:
lm_preds = None
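The reshuffled formatting above materializes only the hidden_* columns as tensors first, then formats the rest of the split. A hedged sketch of the underlying datasets behavior (toy schema, not this repo's):

from datasets import Dataset

ds = Dataset.from_dict({"hidden_0": [[1, 2], [3, 4]], "label": [0, 1]})
ds = ds.with_format("torch", columns=["hidden_0"])  # only hidden_0 is formatted
print(type(ds["hidden_0"]))  # <class 'torch.Tensor'>
ds = ds.with_format("torch")  # now every column comes back as a tensor
print(ds["label"])  # tensor([0, 1])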

3 changes: 2 additions & 1 deletion tests/test_encodings.py
@@ -41,7 +41,8 @@ def map_fn(ex: dict) -> dict:
input_ids = tokenizer(ex["text"], add_special_tokens=True)["input_ids"]
out_record["input_ids"] = [input_ids + suffix_tokens] # type: ignore
answer_ids = [
tokenizer.encode(s, add_special_tokens=False)[0] for s in ["False", "True"]
tokenizer.encode(s, add_special_tokens=False)[0]
for s in [" False", " True"]
]
out_record["answer_ids"] = answer_ids
return out_record
