Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 19, 2023
1 parent 54aa710 commit f7a4713
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from os.path import exists
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from itertools import zip_longest
from os.path import exists
from random import Random
from typing import Any, Iterator, Literal, Optional

Expand Down Expand Up @@ -48,7 +48,7 @@ class PromptConfig(Serializable):
-1.
seed: The seed to use for prompt randomization. Defaults to 42.
stream: Whether to stream the dataset from the Internet. Defaults to False.
combined_template_output_path: Path to save a combined template file to, when
combined_template_output_path: Path to save a combined template file to, when
applying prompt invariance across multiple datasets. Interpreted as a
subpath of `combined_paths` in the templates dir. Defaults to empty string.
"""
Expand Down Expand Up @@ -127,11 +127,11 @@ def combine_templates(self):
"Saved to promptsource/templates/combined_templates/"
+ f"{self.combined_template_output_path}.yaml"
)

def verify_cols(self, ds_builder, ref_ds_builder) -> bool:
'''Verify that number of features and number of classes for ClassLabel
match the expected values.
'''
"""Verify that number of features and number of classes for ClassLabel
match the expected values.
"""
expected_features = len(ref_ds_builder.info.features)
expected_classes = ref_ds_builder.info.features["label"].num_classes
num_features = len(ds_builder.info.features)
Expand Down Expand Up @@ -198,15 +198,13 @@ def load_prompts(
An iterable of prompt dictionaries.
"""
ds_name, _, config_name = ds_string.partition(" ")

prompter = None
if combined_template_output_path and exists(combined_template_output_path):
prompter = DatasetTemplates(
"combined_templates", combined_template_output_path
)
prompter = DatasetTemplates("combined_templates", combined_template_output_path)
else:
prompter = DatasetTemplates(ds_name, config_name)

ds_dict = assert_type(
dict, load_dataset(ds_name, config_name or None, streaming=stream)
)
Expand Down

0 comments on commit f7a4713

Please sign in to comment.