From bdbd7ce4a0345cb5a5884b63659df5f43da3bc1d Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 18 Apr 2023 21:31:23 +0000 Subject: [PATCH] Actually use BalancedSampler again --- elk/extraction/prompt_loading.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index 97d41c80..4fa7ee3d 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -20,7 +20,7 @@ infer_num_classes, select_train_val_splits, ) -from .balanced_sampler import FewShotSampler +from .balanced_sampler import BalancedSampler, FewShotSampler @dataclass @@ -114,10 +114,10 @@ def load_prompts( rank: int = 0, world_size: int = 1, ) -> Iterator[dict]: - """Load a dataset full of prompts generated from the specified datasets. + """Load a dataset full of prompts generated from the specified dataset. Args: - ds_string: Space-delimited name of the HuggingFace datasets to use, + ds_string: Space-delimited name of the HuggingFace dataset to use, e.g. `"super_glue boolq"` or `"imdb"`. num_shots: The number of examples to use in few-shot prompts. If zero, prompts are zero-shot. @@ -128,7 +128,7 @@ def load_prompts( world_size: The number of processes. Defaults to 1. Returns: - An iterable dataset of prompts. + An iterable of prompt dictionaries. """ ds_name, _, config_name = ds_string.partition(" ") prompter = DatasetTemplates(ds_name, config_name) @@ -178,7 +178,7 @@ def load_prompts( extra_cols = list(assert_type(Features, ds.features)) extra_cols.remove(label_column) - for example in ds: + for example in BalancedSampler(ds, num_classes): yield _convert_to_prompts( example, label_column=label_column,