Skip to content

Commit

Permalink
Fix bugs introduced by the refactor; restore the code to a runnable state
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristyKoh committed Apr 12, 2023
1 parent 715bba8 commit 0c2f5c4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
6 changes: 1 addition & 5 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class Extract(Serializable):
token_loc: Literal["first", "last", "mean"] = "last"
min_gpu_mem: Optional[int] = None
num_gpus: int = -1
combined_prompter_path: Optional[str] = None # if template file does not exist, combine from datasets and save to this path

def __post_init__(self, layer_stride: int):
if self.layers and layer_stride > 1:
Expand Down Expand Up @@ -103,8 +102,7 @@ def extract_hiddens(
split_type=split_type,
stream=cfg.prompts.stream,
rank=rank,
world_size=world_size,
combined_prompter_path=cfg.combined_prompter_path
world_size=world_size
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
Expand Down Expand Up @@ -270,8 +268,6 @@ def get_splits() -> SplitDict:
model_cfg = AutoConfig.from_pretrained(cfg.model)
num_variants = cfg.prompts.num_variants

# if combined prompter flag is set, combine prompt templates

# extraneous, remove ?
ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)
Expand Down
2 changes: 1 addition & 1 deletion elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def load_prompts(
label_column=label_column,
num_classes=num_classes,
num_variants=num_variants,
prompter=prompter if not combined_prompter else combined_prompter,
prompter=prompter,
rng=rng,
fewshot_iter=fewshot_iter,
)
Expand Down

0 comments on commit 0c2f5c4

Please sign in to comment.