Reduce VINC reporter file size by >1000x #219

Merged · 6 commits · May 3, 2023
Changes from 1 commit
VINC reporters are now >1000x smaller on disk
norabelrose committed May 2, 2023
commit 5d49eb3d7fdeafc901c64deb06a880ff926bf164
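
Taken as a whole, the commit replaces raw `torch.load` / `dump_yaml` calls with dedicated `Reporter.save` / `Reporter.load` helpers and stops persisting the bulky VINC covariance statistics by default. A minimal sketch of the evaluation-side usage after this commit; the path here is hypothetical (`evaluate.py` derives the real one from the experiment directory):

```python
from elk.training.reporter import Reporter

# Hypothetical path -- evaluate.py builds it from elk_reporter_dir() / source.
reporter_path = "reporters/layer_12.pt"

# Mirrors the updated call in elk/evaluation/evaluate.py below.
reporter = Reporter.load(reporter_path, map_location="cpu")
reporter.eval()
```
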
2 changes: 1 addition & 1 deletion elk/evaluation/evaluate.py
@@ -38,7 +38,7 @@ def apply_to_layer(
experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter: Reporter = torch.load(reporter_path, map_location=device)
reporter = Reporter.load(reporter_path, map_location=device)
reporter.eval()

row_bufs = defaultdict(list)
11 changes: 7 additions & 4 deletions elk/run.py
@@ -13,6 +13,7 @@
import torch.multiprocessing as mp
import yaml
from simple_parsing.helpers import Serializable, field
from simple_parsing.helpers.serialization import save
from torch import Tensor
from tqdm import tqdm

@@ -36,7 +37,9 @@ class Run(ABC, Serializable):
"""Directory to save results to. If None, a directory will be created
automatically."""

datasets: list[DatasetDictWithName] = field(default_factory=list, init=False)
datasets: list[DatasetDictWithName] = field(
default_factory=list, init=False, to_dict=False
)
"""Datasets containing hidden states and labels for each layer."""

concatenated_layer_offset: int = 0
@@ -70,9 +73,9 @@ def execute(self, highlight_color: str = "cyan"):
print(f"Output directory at \033[1m{self.out_dir}\033[0m")
self.out_dir.mkdir(parents=True, exist_ok=True)

path = self.out_dir / "cfg.yaml"
with open(path, "w") as f:
self.dump_yaml(f)
# save_dc_types really ought to be the default... We simply can't load
# properly without this flag enabled.
save(self, self.out_dir / "cfg.yaml", save_dc_types=True)

path = self.out_dir / "fingerprints.yaml"
with open(path, "w") as meta_f:
19 changes: 17 additions & 2 deletions elk/training/ccs_reporter.py
@@ -3,6 +3,7 @@
import math
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional, cast

import torch
@@ -59,7 +60,6 @@ class CcsReporterConfig(ReporterConfig):
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
num_layers: int = 1
pre_ln: bool = False
seed: int = 42
supervised_weight: float = 0.0

lr: float = 1e-2
@@ -68,6 +68,10 @@
optimizer: Literal["adam", "lbfgs"] = "lbfgs"
weight_decay: float = 0.01

@classmethod
def reporter_class(cls) -> type[Reporter]:
return CcsReporter

def __post_init__(self):
self.loss_dict = parse_loss(self.loss)

@@ -94,6 +98,11 @@ def __init__(
):
super().__init__()
self.config = cfg
self.in_features = in_features

# Learnable Platt scaling parameters
self.bias = nn.Parameter(torch.zeros(1, device=device, dtype=dtype))
self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype))

hidden_size = cfg.hidden_size or 4 * in_features // 3

@@ -220,7 +229,7 @@ def reset_parameters(self):

def forward(self, x: Tensor) -> Tensor:
"""Return the raw score output of the probe on `x`."""
return self.probe(x).squeeze(-1)
return self.probe(x).mul(self.scale).add(self.bias).squeeze(-1)

def loss(
self,
@@ -379,3 +388,9 @@ def closure():

optimizer.step(closure)
return float(loss)

def save(self, path: Path | str) -> None:
"""Save the reporter to a file."""
state = {k: v.cpu() for k, v in self.state_dict().items()}
state.update(in_features=self.in_features)
torch.save(state, path)
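
`CcsReporter.save()` now writes a plain dictionary of CPU tensors plus the `in_features` value needed to rebuild the module. A quick sanity check of what lands on disk, assuming a reporter was previously saved to a hypothetical path:

```python
import torch

# Hypothetical path to a file written by CcsReporter.save() above.
state = torch.load("reporters/layer_12.pt", map_location="cpu")

print(state["in_features"])  # e.g. 4096
print({k: tuple(v.shape) for k, v in state.items() if torch.is_tensor(v)})
```
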
143 changes: 46 additions & 97 deletions elk/training/eigen_reporter.py
@@ -2,41 +2,43 @@

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from einops import rearrange, repeat
from torch import Tensor, nn, optim
from einops import rearrange
from torch import Tensor, nn

from ..metrics import to_one_hot
from ..truncated_eigh import truncated_eigh
from ..utils.math_util import cov_mean_fused
from .reporter import Reporter, ReporterConfig


@dataclass
class EigenReporterConfig(ReporterConfig):
"""Configuration for an EigenReporter.
"""Configuration for an EigenReporter."""

Args:
var_weight: The weight of the variance term in the loss.
neg_cov_weight: The weight of the negative covariance term in the loss.
num_heads: The number of reporter heads to fit. In other words, the number
of eigenvectors to compute from the VINC matrix.
"""
var_weight: float = 0.0
"""The weight of the variance term in the loss."""

var_weight: float = 0.1
neg_cov_weight: float = 0.5
"""The weight of the negative covariance term in the loss."""

num_heads: int = 1
"""The number of reporter heads to fit."""

save_reporter_stats: bool = False
"""Whether to save the reporter statistics to disk in EigenReporter.save(). This
is useful for debugging and analysis, but can take up a lot of disk space."""

def __post_init__(self):
if not (0 <= self.neg_cov_weight <= 1):
raise ValueError("neg_cov_weight must be in [0, 1]")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")

@classmethod
def reporter_class(cls) -> type[Reporter]:
return EigenReporter


class EigenReporter(Reporter):
"""A linear reporter whose weights are computed via eigendecomposition.
@@ -69,9 +71,10 @@ class EigenReporter(Reporter):

config: EigenReporterConfig

intercluster_cov_M2: Tensor | None # variance
intracluster_cov: Tensor | None # invariance
contrastive_xcov_M2: Tensor | None # negative covariance
intercluster_cov_M2: Tensor # variance
intracluster_cov: Tensor # invariance
contrastive_xcov_M2: Tensor # negative covariance

n: Tensor
class_means: Tensor | None
weight: Tensor
@@ -80,40 +83,50 @@ def __init__(
self,
cfg: EigenReporterConfig,
in_features: int,
num_classes: int | None = 2,
num_classes: int | None = None,
*,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.config = cfg
self.in_features = in_features
self.num_classes = num_classes

# Learnable Platt scaling parameters
self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype))
self.scale = nn.Parameter(torch.ones(cfg.num_heads, device=device, dtype=dtype))

# Running statistics
self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long))
self.register_buffer(
"n",
torch.zeros((), device=device, dtype=torch.long),
persistent=cfg.save_reporter_stats,
)
self.register_buffer(
"class_means",
(
torch.zeros(num_classes, in_features, device=device, dtype=dtype)
if num_classes is not None
else None
),
persistent=cfg.save_reporter_stats,
)

self.register_buffer(
"contrastive_xcov_M2",
torch.zeros(in_features, in_features, device=device, dtype=dtype),
persistent=cfg.save_reporter_stats,
)
self.register_buffer(
"intercluster_cov_M2",
torch.zeros(in_features, in_features, device=device, dtype=dtype),
persistent=cfg.save_reporter_stats,
)
self.register_buffer(
"intracluster_cov",
torch.zeros(in_features, in_features, device=device, dtype=dtype),
persistent=cfg.save_reporter_stats,
)

# Reporter weights
Expand All @@ -129,10 +142,12 @@ def forward(self, hiddens: Tensor) -> Tensor:

@property
def contrastive_xcov(self) -> Tensor:
assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?"
return self.contrastive_xcov_M2 / self.n

@property
def intercluster_cov(self) -> Tensor:
assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?"
return self.intercluster_cov_M2 / self.n

@property
@@ -141,41 +156,15 @@ def confidence(self) -> Tensor:

@property
def invariance(self) -> Tensor:
assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?"
return -self.weight @ self.intracluster_cov @ self.weight.mT

@property
def consistency(self) -> Tensor:
return -self.weight @ self.contrastive_xcov @ self.weight.mT

def clear(self) -> None:
"""Clear the running statistics of the reporter."""
assert (
self.contrastive_xcov_M2 is not None
and self.intercluster_cov_M2 is not None
and self.intracluster_cov is not None
), "Covariance matrices have been deleted"
self.contrastive_xcov_M2.zero_()
self.intracluster_cov.zero_()
self.intercluster_cov_M2.zero_()
self.n.zero_()

def delete_stats(self) -> None:
"""Delete the running covariance matrices.

This is useful for saving memory when we're done training the reporter.
"""
self.contrastive_xcov_M2 = None
self.intercluster_cov_M2 = None
self.intracluster_cov = None

@torch.no_grad()
def update(self, hiddens: Tensor) -> None:
assert (
self.contrastive_xcov_M2 is not None
and self.intercluster_cov_M2 is not None
and self.intracluster_cov is not None
), "Covariance matrices have been deleted"

(n, _, k, d) = hiddens.shape

# Sanity checks
@@ -228,11 +217,6 @@ def update(self, hiddens: Tensor) -> None:
def fit_streaming(self, truncated: bool = False) -> float:
"""Fit the probe using the current streaming statistics."""
inv_weight = 1 - self.config.neg_cov_weight
assert (
self.contrastive_xcov_M2 is not None
and self.intercluster_cov_M2 is not None
and self.intracluster_cov is not None
), "Covariance matrices have been deleted"
A = (
self.config.var_weight * self.intercluster_cov
- inv_weight * self.intracluster_cov
@@ -260,61 +244,26 @@
self.weight.data = Q.T
return -float(L[-1])

def fit(
self,
hiddens: Tensor,
labels: Optional[Tensor] = None,
) -> float:
def fit(self, hiddens: Tensor) -> float:
"""Fit the probe to the contrast set `hiddens`.

Args:
hiddens: The contrast set of shape [batch, variants, choices, dim].
labels: The ground truth labels if available.

Returns:
loss: Negative eigenvalue associated with the VINC direction.
"""
self.update(hiddens)
loss = self.fit_streaming()

if labels is not None:
(_, v, k, _) = hiddens.shape
hiddens = rearrange(hiddens, "n v k d -> (n v k) d")
labels = to_one_hot(repeat(labels, "n -> (n v)", v=v), k).flatten()

self.platt_scale(labels, hiddens)

return loss

def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100):
"""Fit the scale and bias terms to data with LBFGS.

Args:
labels: Binary labels of shape [batch].
hiddens: Hidden states of shape [batch, dim].
max_iter: Maximum number of iterations for LBFGS.
"""
opt = optim.LBFGS(
[self.bias, self.scale],
line_search_fn="strong_wolfe",
max_iter=max_iter,
tolerance_change=torch.finfo(hiddens.dtype).eps,
tolerance_grad=torch.finfo(hiddens.dtype).eps,
return self.fit_streaming()

def save(self, path: Path | str) -> None:
"""Save the reporter to a file."""
# We basically never want to instantiate the reporter on the same device
# it happened to be trained on, so we save the state dict as CPU tensors.
# Bizarrely, this also seems to save a LOT of disk space in some cases.
state = {k: v.cpu() for k, v in self.state_dict().items()}
state.update(
in_features=self.in_features,
num_classes=self.num_classes,
)

def closure():
opt.zero_grad()
loss = nn.functional.binary_cross_entropy_with_logits(
self(hiddens), labels.float()
)

loss.backward()
return float(loss)

opt.step(closure)

def save(self, path: Path | str):
# TODO: this method will save separate JSON and PT files
if not self.config.save_reporter_stats:
self.delete_stats()
super().save(path)
torch.save(state, path)
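
Most of the size reduction comes from registering the running-statistics buffers with `persistent=cfg.save_reporter_stats` (False by default): non-persistent buffers stay available in memory but are excluded from `state_dict()`, so the large d x d covariance matrices never reach the file written by `save()`. A self-contained PyTorch sketch of that mechanism (illustrative, not code from this repository):

```python
import torch
from torch import nn


class TinyProbe(nn.Module):
    def __init__(self, d: int = 4096):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, d))
        # With persistent=False, the d x d running covariance stays usable in
        # memory but is left out of state_dict() -- and out of the saved file.
        self.register_buffer("cov_M2", torch.zeros(d, d), persistent=False)


probe = TinyProbe()
print(list(probe.state_dict()))  # ['weight'] -- no 'cov_M2'
```
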