Skip to content

Commit

Permalink
Merge pull request EleutherAI#160 from EleutherAI/fix-max-examples
Browse files (browse the repository at this point in the history)
split max_examples between processes
  • Loading branch information
AlexTMallen committed Mar 31, 2023
2 parents 156b596 + cb6e5a8 commit db3f9d5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
6 changes: 4 additions & 2 deletions elk/extraction/balanced_sampler.py
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
from ..math_util import stochastic_round_constrained
from ..utils import infer_label_column
from ..utils.typing import assert_type
from collections import deque
from dataclasses import dataclass
from datasets import IterableDataset
from datasets import IterableDataset, Features
from itertools import cycle
from random import Random
from torch.utils.data import IterableDataset as TorchIterableDataset
Expand Down Expand Up @@ -62,7 +63,8 @@ def __init__(
label_col: Optional[str] = None,
):
self.dataset = dataset
self.label_col = label_col or infer_label_column(dataset.features)
feats = assert_type(Features, dataset.features)
self.label_col = label_col or infer_label_column(feats)
self.num_shots = num_shots
self.rng = rng

Expand Down
10 changes: 9 additions & 1 deletion elk/extraction/extraction.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -128,7 +128,15 @@ def extract_hiddens(
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
# print(f"Using {prompt_ds} variants for each dataset")

max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
# break `max_examples` among the processes roughly equally
max_examples = global_max_examples // world_size
# the last process gets the remainder (which is usually small)
if rank == world_size - 1:
max_examples += global_max_examples % world_size

print(f"Extracting {max_examples} examples from {prompt_ds} on {device}")

for example in islice(BalancedSampler(prompt_ds), max_examples):
num_variants = len(example["prompts"])
hidden_dict = {
Expand Down
6 changes: 5 additions & 1 deletion elk/logging.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,11 +15,15 @@ def save_debug_log(ds, out_dir):
filemode="w",
)

_, val_split = select_train_val_splits(ds)
train_split, val_split = select_train_val_splits(ds)
text_inputs = ds[val_split][0]["text_inputs"]
template_ids = ds[val_split][0]["variant_ids"]
label = ds[val_split][0]["label"]

# log the train size and val size
logging.info(f"Train size: {len(ds[train_split])}")
logging.info(f"Val size: {len(ds[val_split])}")

templates_text = f"{len(text_inputs)} templates used:\n"
trailing_whitespace = False
for (text0, text1), id in zip(text_inputs, template_ids):
Expand Down

0 comments on commit db3f9d5

Please sign in to comment.