Skip to content

Commit

Permalink
Support multi class again
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Mar 22, 2023
1 parent d304ab3 commit 761c82d
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 84 deletions.
11 changes: 5 additions & 6 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ class PromptConfig(Serializable):
label_column: The column containing the labels. By default, we infer this from
the datatypes of the columns in the dataset; if there is only one column
with a `ClassLabel` datatype, we use that.
num_classes: The number of classes in the dataset. By default, we infer this
from the datatypes of the columns in the dataset; if there is only one
column with a `ClassLabel` datatype, we use the number of classes in that
column.
max_examples: The maximum number of examples to use from the val dataset.
If a single number, use at most that many examples for each split. If a list
of length 2, use the first element for the train split and the second for
Expand All @@ -53,7 +49,6 @@ class PromptConfig(Serializable):
balance: bool = False
data_dir: Optional[str] = None
label_column: Optional[str] = None
num_classes: Optional[int] = None
max_examples: list[int] = field(default_factory=lambda: [750, 250])
num_shots: int = 0
num_variants: int = -1
Expand Down Expand Up @@ -111,7 +106,6 @@ def load_prompts(
for ds, prompter in zip(raw_datasets, prompters):
label_column = infer_label_column(ds.features)
num_classes = infer_num_classes(ds.features[label_column])
assert num_classes == 2

# Remove everything but the label column
extra_cols = list(assert_type(Features, ds.features))
Expand Down Expand Up @@ -183,6 +177,11 @@ def _convert_to_prompts(
for template in templates:
choices = []

if num_classes > 2:
template, _ = binarize(
template, example[label_column], rng.choice([0, 1]), rng
)

for answer_idx in range(num_classes):
fake_example = example.copy()
fake_example[label_column] = answer_idx
Expand Down
10 changes: 10 additions & 0 deletions tests/dbpedia_prompts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Prompt-loading configuration read by tests/test_load_prompts.py through
# PromptConfig.load_yaml. The test uses it as its "more than 2 classes" case.
balance: true
# Each entry is "<dataset name>" or "<dataset name> <config name>"; the test
# splits an entry on the first space to recover the two parts.
datasets:
- "dbpedia_14"
# null: the label column is inferred from the datatypes of the dataset's
# columns (see the PromptConfig docstring).
label_column: null
# Per-split example caps. Per the PromptConfig docstring the first element is
# for the train split; the second is presumably for val -- TODO confirm.
max_examples:
- 5
- 5
num_shots: 0
# NOTE(review): -1 presumably means "use every available template" -- confirm
# against PromptConfig.
num_variants: -1
seed: 42
16 changes: 0 additions & 16 deletions tests/distilgpt2_copa_cfg.yaml

This file was deleted.

16 changes: 0 additions & 16 deletions tests/distilgpt2_dbpedia_cfg.yaml

This file was deleted.

11 changes: 11 additions & 0 deletions tests/super_glue_prompts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Prompt-loading configuration read by tests/test_load_prompts.py through
# PromptConfig.load_yaml. The test uses it as its "2 classes" case.
balance: true
# Each entry is "<dataset name> <config name>"; the test splits an entry on
# the first space to recover the two parts.
datasets:
- "super_glue boolq"
- "super_glue copa"
# null: the label column is inferred from the datatypes of the dataset's
# columns (see the PromptConfig docstring).
label_column: null
# Per-split example caps. Per the PromptConfig docstring the first element is
# for the train split; the second is presumably for val -- TODO confirm.
max_examples:
- 5
- 5
num_shots: 0
# NOTE(review): -1 presumably means "use every available template" -- confirm
# against PromptConfig.
num_variants: -1
seed: 42
39 changes: 39 additions & 0 deletions tests/test_load_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from elk.extraction import load_prompts, PromptConfig
from elk.promptsource.templates import DatasetTemplates
from itertools import cycle
from typing import Literal
import pytest


@pytest.mark.filterwarnings("ignore:Unable to find a decoding function")
def test_load_prompts():
    """Check that `load_prompts` reports the right template names per record.

    For every dataset named in a prompt config, the `template_names` field of
    each record yielded by `load_prompts` must match, in content and in order,
    the `all_template_names` of the `DatasetTemplates` built for that dataset.
    Both a two-class config and a config with more than two classes are
    exercised, on the "train" and "val" splits.
    """

    def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]):
        """Run the template-name checks for one split of the given config.

        Args:
            cfg: Prompt configuration listing the datasets to load.
            split_type: Which split to request from `load_prompts`.

        Raises:
            AssertionError: If a record's template names differ from the
                reference prompter's names, either as a set or in order.
        """
        prompt_ds = load_prompts(
            *cfg.datasets, max_examples=cfg.max_examples[0], split_type=split_type
        )

        # Build one reference prompter per dataset. An entry is either
        # "<name>" or "<name> <config>"; an empty config becomes None.
        prompters = []
        for ds in cfg.datasets:
            ds_name, _, config_name = ds.partition(" ")
            prompters.append(DatasetTemplates(ds_name, config_name or None))

        # NOTE(review): pairing records with `cycle(prompters)` assumes that
        # `load_prompts` yields records round-robin across the datasets, in
        # the order they were passed -- confirm against `load_prompts`.
        for prompter, record in zip(cycle(prompters), prompt_ds):
            true_template_names = prompter.all_template_names
            returned_template_names = record["template_names"]

            # check for using the same templates
            assert set(true_template_names) == set(returned_template_names), (
                f"template names differ: expected {true_template_names}, "
                f"got {returned_template_names}"
            )
            # check for them being in the same order; the expected names must
            # be compared against the *returned* names (comparing the expected
            # list with itself can never fail)
            assert list(true_template_names) == list(returned_template_names), (
                f"template order differs: expected {true_template_names}, "
                f"got {returned_template_names}"
            )

    # the case where the dataset has 2 classes
    # this dataset is small
    cfg = PromptConfig.load_yaml("tests/super_glue_prompts.yaml")
    test_single_split(cfg, "train")
    test_single_split(cfg, "val")

    # the case where the dataset has more than 2 classes
    cfg = PromptConfig.load_yaml("tests/dbpedia_prompts.yaml")
    test_single_split(cfg, "train")
    test_single_split(cfg, "val")
46 changes: 0 additions & 46 deletions tests/test_prompt_dataset.py

This file was deleted.

0 comments on commit 761c82d

Please sign in to comment.