From 361fb9b0ed0e8b42713435018f0b9f2171925cf2 Mon Sep 17 00:00:00 2001
From: Nora Belrose <39116809+norabelrose@users.noreply.github.com>
Date: Wed, 12 Apr 2023 10:11:11 -0700
Subject: [PATCH] Prevent invalidation of the hidden state cache when
 num_gpus changes (#182)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fix bug where cached hidden states aren’t used when num_gpus is different

* Actually works now
---
 elk/extraction/extraction.py   | 15 +++++++++------
 elk/extraction/generator.py    | 18 ++++++++++++++++++
 elk/training/eigen_reporter.py |  2 +-
 3 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 74f4e489..a9abd6a6 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -1,6 +1,7 @@
 """Functions for extracting the hidden states of a model."""
 import logging
 import os
+from copy import copy
 from dataclasses import InitVar, dataclass
 from itertools import islice
 from typing import Any, Iterable, Literal, Optional
@@ -22,7 +23,6 @@
 from transformers import AutoConfig, AutoTokenizer
 from transformers.modeling_outputs import Seq2SeqLMOutput
 
-# import torch.nn.functional as F
 from ..utils import (
     assert_type,
     convert_span,
@@ -87,10 +87,7 @@ def extract_hiddens(
     rank: int = 0,
     world_size: int = 1,
 ) -> Iterable[dict]:
-    """Run inference on a model with a set of prompts, yielding the hidden states.
-
-    This is a lightweight, functional version of the `Extractor` API.
-    """
+    """Run inference on a model with a set of prompts, yielding the hidden states."""
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
     # Silence datasets logging messages from all but the first process
@@ -301,6 +298,12 @@ def get_splits() -> SplitDict:
         )
 
     devices = select_usable_devices(num_gpus, min_memory=cfg.min_gpu_mem)
+
+    # Prevent the GPU-related config options from invalidating the cache
+    _cfg = copy(cfg)
+    _cfg.min_gpu_mem = None
+    _cfg.num_gpus = -1
+
     builders = {
         split_name: _GeneratorBuilder(
             cache_dir=None,
@@ -309,7 +312,7 @@ def get_splits() -> SplitDict:
             split_name=split_name,
             split_info=split_info,
             gen_kwargs=dict(
-                cfg=[cfg] * len(devices),
+                cfg=[_cfg] * len(devices),
                 device=devices,
                 rank=list(range(len(devices))),
                 split_type=[split_name] * len(devices),
diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py
index fb4d03bc..e3cad0e5 100644
--- a/elk/extraction/generator.py
+++ b/elk/extraction/generator.py
@@ -1,7 +1,9 @@
+from copy import deepcopy
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, Optional
 
 import datasets
+from datasets import Features
 from datasets.splits import NamedSplit
 
 
@@ -11,6 +13,22 @@ class _GeneratorConfig(datasets.BuilderConfig):
     gen_kwargs: dict[str, Any] = field(default_factory=dict)
     features: Optional[datasets.Features] = None
 
+    def create_config_id(
+        self, config_kwargs: dict, custom_features: Features | None
+    ) -> str:
+        config_kwargs = deepcopy(config_kwargs)
+
+        # By default the values in gen_kwargs are lists of length world_size. We want
+        # to erase the world_size dimension so that the config id is the same no matter
+        # how many processes are used. We also remove the explicit device, rank, and
+        # world_size keys.
+        config_kwargs["gen_kwargs"] = {
+            k: v[0]
+            for k, v in config_kwargs.get("gen_kwargs", {}).items()
+            if k not in ("device", "rank", "world_size")
+        }
+        return super().create_config_id(config_kwargs, custom_features)
+
 
 @dataclass
 class _SplitGenerator:
diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py
index 1d2f6a98..821891cc 100644
--- a/elk/training/eigen_reporter.py
+++ b/elk/training/eigen_reporter.py
@@ -180,7 +180,7 @@ def fit_streaming(self) -> float:
 
         try:
             L, Q = truncated_eigh(A, k=self.config.num_heads)
-        except ConvergenceError:
+        except (ConvergenceError, RuntimeError):
             warn(
                 "Truncated eigendecomposition failed to converge. Falling back on "
                 "PyTorch's dense eigensolver."
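
Why this works: the `datasets` cache fingerprint is derived from the builder config, and `gen_kwargs` previously baked in the device list, the ranks, and per-process copies of `cfg`, so any change in GPU count produced a new config id and a cold cache. The sketch below illustrates the normalization idea with a hypothetical `config_id` helper and made-up sample dicts; it is not the real `datasets.BuilderConfig.create_config_id`, just a self-contained approximation of what the override above feeds into it.

# Hypothetical, self-contained sketch; not part of the patch above.
from hashlib import sha256


def config_id(gen_kwargs: dict) -> str:
    # Mirror the normalization in _GeneratorConfig.create_config_id: keep only
    # the first element of each per-process list, drop the keys that vary with
    # the GPU setup, then hash what remains.
    normalized = {
        k: v[0]
        for k, v in gen_kwargs.items()
        if k not in ("device", "rank", "world_size")
    }
    return sha256(repr(sorted(normalized.items())).encode()).hexdigest()[:16]


# The same extraction config, run on one GPU or on four, now hashes
# identically, so the second run reuses the cached hidden states.
one_gpu = {
    "cfg": ["my_cfg"],
    "device": ["cuda:0"],
    "rank": [0],
    "split_type": ["train"],
}
four_gpus = {
    "cfg": ["my_cfg"] * 4,
    "device": [f"cuda:{i}" for i in range(4)],
    "rank": list(range(4)),
    "split_type": ["train"] * 4,
}
assert config_id(one_gpu) == config_id(four_gpus)

The `_cfg` copy in extraction.py attacks the same problem from the other side: since `cfg` itself appears inside `gen_kwargs`, its GPU-related fields (`min_gpu_mem`, `num_gpus`) must also be pinned to fixed values before the config is hashed.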