diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index e6432500..f0ca68e6 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -1,8 +1,8 @@ -from os.path import exists from collections import Counter from copy import deepcopy from dataclasses import dataclass from itertools import zip_longest +from os.path import exists from random import Random from typing import Any, Iterator, Literal, Optional @@ -48,7 +48,7 @@ class PromptConfig(Serializable): -1. seed: The seed to use for prompt randomization. Defaults to 42. stream: Whether to stream the dataset from the Internet. Defaults to False. - combined_template_output_path: Path to save a combined template file to, when + combined_template_output_path: Path to save a combined template file to, when applying prompt invariance across multiple datasets. Interpreted as a subpath of `combined_paths` in the templates dir. Defaults to empty string. """ @@ -127,11 +127,11 @@ def combine_templates(self): "Saved to promptsource/templates/combined_templates/" + f"{self.combined_template_output_path}.yaml" ) - + def verify_cols(self, ds_builder, ref_ds_builder) -> bool: - '''Verify that number of features and number of classes for ClassLabel - match the expected values. - ''' + """Verify that number of features and number of classes for ClassLabel + match the expected values. + """ expected_features = len(ref_ds_builder.info.features) expected_classes = ref_ds_builder.info.features["label"].num_classes num_features = len(ds_builder.info.features) @@ -198,15 +198,13 @@ def load_prompts( An iterable of prompt dictionaries. """ ds_name, _, config_name = ds_string.partition(" ") - + prompter = None if combined_template_output_path and exists(combined_template_output_path): - prompter = DatasetTemplates( - "combined_templates", combined_template_output_path - ) + prompter = DatasetTemplates("combined_templates", combined_template_output_path) else: prompter = DatasetTemplates(ds_name, config_name) - + ds_dict = assert_type( dict, load_dataset(ds_name, config_name or None, streaming=stream) )