Skip to content

Commit

Permalink
fix ds_name bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristyKoh committed Apr 20, 2023
1 parent 3b0765d commit 5f0f32a
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def combine_templates(self):
if i == 0:
# Set first dataset as reference
ref_ds_builder = ds_builder
elif not self.verify_cols(ds_builder, ref_ds_builder):
elif not self.verify_cols(ref_ds_builder, ds_builder, ds_name):
return

# Once verified, merge templates.
Expand All @@ -128,15 +128,14 @@ def combine_templates(self):
+ f"{self.combined_template_output_path}.yaml"
)

def verify_cols(self, ds_builder, ref_ds_builder) -> bool:
def verify_cols(self, ref_ds_builder, ds_builder, ds_name) -> bool:
"""Verify that number of features and number of classes for ClassLabel
match the expected values.
"""
expected_features = len(ref_ds_builder.info.features)
expected_classes = ref_ds_builder.info.features["label"].num_classes
num_features = len(ds_builder.info.features)
num_classes = ds_builder.info.features["label"].num_classes
ds_name = ds_builder.builder_name
if expected_features > 0 and num_features != expected_features:
print(
"WARNING: Datasets do not have the same number of features;",
Expand Down

0 comments on commit 5f0f32a

Please sign in to comment.