diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index a4102554..88687e0c 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -115,7 +115,7 @@ def combine_templates(self): if i == 0: # Set first dataset as reference ref_ds_builder = ds_builder - elif not self.verify_cols(ds_builder, ref_ds_builder): + elif not self.verify_cols(ref_ds_builder, ds_builder, ds_name): return # Once verified, merge templates. @@ -128,7 +128,7 @@ def combine_templates(self): + f"{self.combined_template_output_path}.yaml" ) - def verify_cols(self, ds_builder, ref_ds_builder) -> bool: + def verify_cols(self, ref_ds_builder, ds_builder, ds_name) -> bool: """Verify that number of features and number of classes for ClassLabel match the expected values. """ @@ -136,7 +136,6 @@ def verify_cols(self, ds_builder, ref_ds_builder) -> bool: expected_classes = ref_ds_builder.info.features["label"].num_classes num_features = len(ds_builder.info.features) num_classes = ds_builder.info.features["label"].num_classes - ds_name = ds_builder.builder_name if expected_features > 0 and num_features != expected_features: print( "WARNING: Datasets do not have the same number of features;",