diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index c5c9dffe..744da86f 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -160,8 +160,8 @@ def load_prompts( if rank == 0: print(f"Using {num_variants} variants of each prompt") - label_column = infer_label_column(ds.features) - num_classes = infer_num_classes(ds.features[label_column]) + label_column = label_column or infer_label_column(ds.features) + num_classes = num_classes or infer_num_classes(ds.features[label_column]) rng = Random(seed) if num_shots > 0: