Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smoke tests with tiny gpt2, fix CCSReporter #149

Merged
merged 10 commits into from
Mar 24, 2023
Prev Previous commit
Next Next commit
undo changes
  • Loading branch information
thejaminator committed Mar 24, 2023
commit c124d40da1ec2b1b4ea0a444eed75c12dca29bd1
31 changes: 14 additions & 17 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""Functions for extracting the hidden states of a model."""

import logging
from .prompt_dataset import Prompt, PromptDataset, PromptConfig
from ..utils import (
assert_type,
float32_to_int16,
infer_label_column,
select_train_val_splits,
select_usable_devices,
)
from .generator import _GeneratorBuilder
from dataclasses import dataclass, InitVar
from typing import Iterable, Literal, Union

import torch
from datasets import (
Array3D,
DatasetDict,
Expand All @@ -15,7 +20,6 @@
SplitDict,
SplitInfo,
Value,
Dataset,
)
from simple_parsing.helpers import field, Serializable
from transformers import (
Expand All @@ -25,16 +29,9 @@
BatchEncoding,
PreTrainedModel,
)

from .generator import _GeneratorBuilder
from .prompt_dataset import Prompt, PromptDataset, PromptConfig
from ..utils import (
assert_type,
float32_to_int16,
infer_label_column,
select_train_val_splits,
select_usable_devices,
)
from typing import Iterable, Literal, Union
import logging
import torch


@dataclass
Expand Down Expand Up @@ -275,7 +272,7 @@ def get_splits() -> SplitDict:
),
}
devices = select_usable_devices(max_gpus)
builders: dict[Split, _GeneratorBuilder] = {
builders = {
split_name: _GeneratorBuilder(
cache_dir=None,
features=Features({**layer_cols, **other_cols}),
Expand All @@ -293,7 +290,7 @@ def get_splits() -> SplitDict:
for (split_name, split_info) in splits.items()
}

ds: dict[Split, Union[Dataset, DatasetDict]] = dict()
ds = dict()
for split, builder in builders.items():
builder.download_and_prepare(num_proc=len(devices))
ds[split] = builder.as_dataset(split=split)
Expand Down