From e1b183456e8e679f5735f1993a65cd158e921aab Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 18 Apr 2023 04:39:25 +0000 Subject: [PATCH] Make load_prompts actually use label_column and num_classes args --- elk/extraction/prompt_loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py index c5c9dffe..744da86f 100644 --- a/elk/extraction/prompt_loading.py +++ b/elk/extraction/prompt_loading.py @@ -160,8 +160,8 @@ def load_prompts( if rank == 0: print(f"Using {num_variants} variants of each prompt") - label_column = infer_label_column(ds.features) - num_classes = infer_num_classes(ds.features[label_column]) + label_column = label_column or infer_label_column(ds.features) + num_classes = num_classes or infer_num_classes(ds.features[label_column]) rng = Random(seed) if num_shots > 0: