
Commit 34bc364
Merge remote-tracking branch 'origin/main' into spar_mt_prompt_invar
ChristyKoh committed Apr 12, 2023
2 parents b0c0f63 + 361fb9b commit 34bc364
Showing 3 changed files with 28 additions and 7 deletions.
elk/extraction/extraction.py (9 additions & 6 deletions)
@@ -1,6 +1,7 @@
"""Functions for extracting the hidden states of a model."""
import logging
import os
from copy import copy
from dataclasses import InitVar, dataclass
from itertools import islice
from typing import Any, Iterable, Literal, Optional
@@ -22,7 +23,6 @@
 from transformers import AutoConfig, AutoTokenizer
 from transformers.modeling_outputs import Seq2SeqLMOutput

-# import torch.nn.functional as F
 from ..utils import (
     assert_type,
     convert_span,
@@ -87,10 +87,7 @@ def extract_hiddens(
     rank: int = 0,
     world_size: int = 1,
 ) -> Iterable[dict]:
-    """Run inference on a model with a set of prompts, yielding the hidden states.
-    This is a lightweight, functional version of the `Extractor` API.
-    """
+    """Run inference on a model with a set of prompts, yielding the hidden states."""
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

     # Silence datasets logging messages from all but the first process
@@ -304,6 +301,12 @@ def get_splits() -> SplitDict:
     )

     devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem)
+
+    # Prevent the GPU-related config options from invalidating the cache
+    _cfg = copy(cfg)
+    _cfg.min_gpu_mem = None
+    _cfg.num_gpus = -1
+
     builders = {
         split_name: _GeneratorBuilder(
             cache_dir=None,
@@ -312,7 +315,7 @@ def get_splits() -> SplitDict:
             split_name=split_name,
             split_info=split_info,
             gen_kwargs=dict(
-                cfg=[cfg] * len(devices),
+                cfg=[_cfg] * len(devices),
                 device=devices,
                 rank=list(range(len(devices))),
                 split_type=[split_name] * len(devices),
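Context for the extraction.py change above: the datasets library derives its cache fingerprint from the builder config's contents, so two runs that differ only in GPU settings would otherwise fail to share a cache entry. A minimal standalone sketch of the idea; `ExtractCfg` and `cache_key` are hypothetical stand-ins, not elk code:

```python
from copy import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class ExtractCfg:
    # Hypothetical stand-in for the real extraction config.
    model: str
    num_gpus: int = -1
    min_gpu_mem: Optional[int] = None

def cache_key(cfg: ExtractCfg) -> str:
    # Stand-in for the fingerprint the datasets library derives from the config.
    return repr(cfg)

a = ExtractCfg("gpt2", num_gpus=2, min_gpu_mem=8_000_000_000)
b = ExtractCfg("gpt2", num_gpus=8)

keys = set()
for cfg in (a, b):
    _cfg = copy(cfg)         # shallow copy, as in the diff
    _cfg.min_gpu_mem = None  # neutralize the GPU-related fields so they
    _cfg.num_gpus = -1       # cannot change the cache key
    keys.add(cache_key(_cfg))

assert len(keys) == 1  # both runs now map to the same cache entry
```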
elk/extraction/generator.py (18 additions & 0 deletions)
@@ -1,7 +1,9 @@
+from copy import deepcopy
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, Optional

 import datasets
+from datasets import Features
 from datasets.splits import NamedSplit


@@ -11,6 +13,22 @@ class _GeneratorConfig(datasets.BuilderConfig):
     gen_kwargs: dict[str, Any] = field(default_factory=dict)
     features: Optional[datasets.Features] = None

+    def create_config_id(
+        self, config_kwargs: dict, custom_features: Features | None
+    ) -> str:
+        config_kwargs = deepcopy(config_kwargs)
+
+        # By default the values in gen_kwargs are lists of length world_size. We want
+        # to erase the world_size dimension so that the config id is the same no matter
+        # how many processes are used. We also remove the explicit device, rank, and
+        # world_size keys.
+        config_kwargs["gen_kwargs"] = {
+            k: v[0]
+            for k, v in config_kwargs.get("gen_kwargs", {}).items()
+            if k not in ("device", "rank", "world_size")
+        }
+        return super().create_config_id(config_kwargs, custom_features)


@dataclass
class _SplitGenerator:
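Context for the generator.py change above: `create_config_id` feeds the dataset fingerprint, so collapsing the per-process lists makes that fingerprint invariant to how many GPUs ran the extraction. A self-contained sketch of the same normalization applied to made-up `gen_kwargs` (the real values are assembled in extraction.py):

```python
from copy import deepcopy

def normalize(config_kwargs: dict) -> dict:
    """Mirror of the gen_kwargs normalization in _GeneratorConfig.create_config_id."""
    config_kwargs = deepcopy(config_kwargs)
    config_kwargs["gen_kwargs"] = {
        k: v[0]  # drop the world_size axis, keeping one representative value
        for k, v in config_kwargs.get("gen_kwargs", {}).items()
        if k not in ("device", "rank", "world_size")
    }
    return config_kwargs

# Hypothetical kwargs for a 2-GPU and a 1-GPU run of the same extraction.
two_gpu = {"gen_kwargs": {"cfg": ["cfg"] * 2, "device": ["cuda:0", "cuda:1"],
                          "rank": [0, 1], "split_type": ["train"] * 2}}
one_gpu = {"gen_kwargs": {"cfg": ["cfg"], "device": ["cuda:0"],
                          "rank": [0], "split_type": ["train"]}}

# Both collapse to {'gen_kwargs': {'cfg': 'cfg', 'split_type': 'train'}}, so the
# derived config id (and hence the cache entry) no longer depends on world size.
assert normalize(two_gpu) == normalize(one_gpu)
```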
elk/training/eigen_reporter.py (1 addition & 1 deletion)
@@ -180,7 +180,7 @@ def fit_streaming(self) -> float:

         try:
             L, Q = truncated_eigh(A, k=self.config.num_heads)
-        except ConvergenceError:
+        except (ConvergenceError, RuntimeError):
             warn(
                 "Truncated eigendecomposition failed to converge. Falling back on "
                 "PyTorch's dense eigensolver."
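Context for the eigen_reporter.py change above: sparse eigensolvers sometimes surface convergence failures as a plain `RuntimeError` rather than a dedicated exception, hence the widened except clause. A minimal sketch of the try-then-dense-fallback pattern, assuming only `torch`; `truncated_eigh` and `ConvergenceError` are elk internals not reproduced here, and `flaky` is a hypothetical solver:

```python
import warnings
import torch

def top_k_eigh(A: torch.Tensor, k: int, sparse_solver=None):
    """Try a (possibly flaky) truncated solver first, then fall back to a dense one."""
    if sparse_solver is not None:
        try:
            return sparse_solver(A, k=k)
        except RuntimeError:  # the diff widens the caught types to include this
            warnings.warn(
                "Truncated eigendecomposition failed to converge. Falling back on "
                "PyTorch's dense eigensolver."
            )
    L, Q = torch.linalg.eigh(A)  # eigenvalues in ascending order
    return L[-k:], Q[:, -k:]     # keep the k largest eigenpairs

def flaky(A, k):  # hypothetical solver that never converges
    raise RuntimeError("lanczos failed to converge")

A = torch.randn(8, 8)
A = A + A.T  # symmetrize so eigh applies
L, Q = top_k_eigh(A, k=2, sparse_solver=flaky)
```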
