Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort outputs by variant id during extraction #122

Merged
merged 9 commits into from
Mar 17, 2023
Prev Previous commit
only preserve order for the all variants case
  • Loading branch information
AlexTMallen committed Mar 17, 2023
commit 0f78b9ad96ca4f7904d6ca6fb4252c669c9e74c4
8 changes: 1 addition & 7 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,6 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]:
# Iterating over questions
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
for prompts in prompt_ds:
variant_ids = [prompt.template_name for prompt in prompts]

# Sort the variants and prompts by their ID to standardize the order
sorted_idxs = sorted(range(len(variant_ids)), key=variant_ids.__getitem__)
variant_ids = [variant_ids[i] for i in sorted_idxs]
prompts = [prompts[i] for i in sorted_idxs]

inputs = collate(prompts)
hidden_dict = {
f"hidden_{layer_idx}": torch.empty(
Expand All @@ -181,6 +174,7 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]:
)
for layer_idx in layer_indices
}
variant_ids = [prompt.template_name for prompt in prompts]

# Iterate over variants
for i, variant_inputs in enumerate(inputs):
Expand Down
6 changes: 4 additions & 2 deletions elk/extraction/prompt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,10 @@ def __init__(
def __getitem__(self, index: int) -> list[Prompt]:
"""Get a list of prompts for a given predicate"""
# get self.num_variants unique prompts from the template pool
template_names = self.rng.sample(
list(self.prompter.templates), self.num_variants
template_names = (
self.rng.sample(list(self.prompter.templates), self.num_variants)
if self.num_variants < len(self.prompter.templates)
else list(self.prompter.templates)
)

example = self.active_split[index]
Expand Down