Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple datasets refactor #189

Merged
merged 17 commits into from
Apr 14, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
All tests working ostensibly
  • Loading branch information
norabelrose committed Apr 13, 2023
commit a6c382e2deaadaa45dc6cee212e96c0c481657c7
32 changes: 14 additions & 18 deletions tests/test_load_prompts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from itertools import cycle, islice
from itertools import islice
from typing import Literal

import pytest
Expand All @@ -10,26 +10,22 @@
@pytest.mark.filterwarnings("ignore:Unable to find a decoding function")
def test_load_prompts():
def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]):
prompt_ds = load_prompts(
*cfg.datasets,
split_type=split_type,
)
prompters = []

for ds in cfg.datasets:
ds_name, _, config_name = ds.partition(" ")
for cfg in cfg.explode():
ds_string = cfg.datasets[0]
prompt_ds = load_prompts(ds_string, split_type=split_type)

ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name or None)
prompters.append(prompter)

limit = cfg.max_examples[0 if split_type == "train" else 1]
for prompter, record in zip(cycle(prompters), islice(prompt_ds, limit)):
true_template_names = prompter.all_template_names
returned_template_names = record["template_names"]
limit = cfg.max_examples[0 if split_type == "train" else 1]
for record in islice(prompt_ds, limit):
true_template_names = prompter.all_template_names
returned_template_names = record["template_names"]

# check for using the same templates
assert set(true_template_names) == set(returned_template_names)
# check for them being in the same order
assert true_template_names == true_template_names
# check for using the same templates
assert set(true_template_names) == set(returned_template_names)
# check for them being in the same order
assert true_template_names == true_template_names

# the case where the dataset has 2 classes
# this dataset is small
Expand Down