Skip to content

Commit

Permalink
Merge branch 'multi-datasets' of github.com:EleutherAI/elk into multi-datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Mar 24, 2023
2 parents a858b65 + 9368dc8 commit 5dc2ec6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
13 changes: 8 additions & 5 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def load_prompts(
label_column = infer_label_column(ds.features)
num_classes = infer_num_classes(ds.features[label_column])

# Remove everything but the label column
# Remove everything except the label column
extra_cols = list(assert_type(Features, ds.features))
extra_cols.remove(label_column)

Expand Down Expand Up @@ -156,7 +156,7 @@ def load_prompts(
"prompts": Sequence(
Sequence(
{"answer": "string", "text": "string"},
length=num_classes,
length=2, # contrast pair
),
length=num_variants,
),
Expand Down Expand Up @@ -196,15 +196,17 @@ def qa_cat(q: str, a: str) -> str:
sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " "
return f"{q}{sep}{a}" if a and not a.isspace() else q

new_label = rng.choice([0, 1]) if num_classes > 2 else example[label_column]

for template in templates:
choices = []

if num_classes > 2:
template, _ = binarize(
template, example[label_column], rng.choice([0, 1]), rng
template = binarize(
template, example[label_column], assert_type(int, new_label), rng
)

for answer_idx in range(num_classes):
for answer_idx in range(2):
fake_example = example.copy()
fake_example[label_column] = answer_idx

Expand All @@ -231,6 +233,7 @@ def qa_cat(q: str, a: str) -> str:
prompts.append(choices)

return dict(
label=new_label,
prompts=prompts,
template_names=prompter.all_template_names,
)
6 changes: 2 additions & 4 deletions elk/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ def int16_to_float32(x: torch.Tensor) -> torch.Tensor:
return x.view(torch.float16).type(torch.float32)


def binarize(
template: Template, label: int, new_label: int, rng: Random
) -> tuple[Template, int]:
def binarize(template: Template, label: int, new_label: int, rng: Random) -> Template:
"""Binarize a template with >2 answer choices, returning a new template and label.
Returns:
Expand All @@ -117,4 +115,4 @@ def binarize(
f"{false} ||| {true}" if new_label else f"{true} ||| {false}"
)

return new_template, new_label
return new_template

0 comments on commit 5dc2ec6

Please sign in to comment.