Skip to content

Commit

Permalink
Actually use BalancedSampler again
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Apr 18, 2023
1 parent 5ac38e7 commit bdbd7ce
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
infer_num_classes,
select_train_val_splits,
)
from .balanced_sampler import FewShotSampler
from .balanced_sampler import BalancedSampler, FewShotSampler


@dataclass
Expand Down Expand Up @@ -114,10 +114,10 @@ def load_prompts(
rank: int = 0,
world_size: int = 1,
) -> Iterator[dict]:
"""Load a dataset full of prompts generated from the specified datasets.
"""Load a dataset full of prompts generated from the specified dataset.
Args:
ds_string: Space-delimited name of the HuggingFace datasets to use,
ds_string: Space-delimited name of the HuggingFace dataset to use,
e.g. `"super_glue boolq"` or `"imdb"`.
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot.
Expand All @@ -128,7 +128,7 @@ def load_prompts(
world_size: The number of processes. Defaults to 1.
Returns:
An iterable dataset of prompts.
An iterable of prompt dictionaries.
"""
ds_name, _, config_name = ds_string.partition(" ")
prompter = DatasetTemplates(ds_name, config_name)
Expand Down Expand Up @@ -178,7 +178,7 @@ def load_prompts(
extra_cols = list(assert_type(Features, ds.features))
extra_cols.remove(label_column)

for example in ds:
for example in BalancedSampler(ds, num_classes):
yield _convert_to_prompts(
example,
label_column=label_column,
Expand Down

0 comments on commit bdbd7ce

Please sign in to comment.