Reduce VINC reporter file size by >1000x #219

Merged · 6 commits · May 3, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion elk/evaluation/evaluate.py
@@ -40,7 +40,7 @@ def apply_to_layer(
experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter: Reporter = torch.load(reporter_path, map_location=device)
reporter = Reporter.load(reporter_path, map_location=device)
reporter.eval()

row_bufs = defaultdict(list)
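The call site now rebuilds the reporter through `Reporter.load` instead of un-pickling a whole module with `torch.load`. The classmethod itself is defined outside this diff, so the toy class below is only a sketch of the save/load pattern the PR moves to (a plain state dict plus the constructor arguments needed to rebuild the module); every name in it is illustrative, not the repo's.

```python
# Illustrative only: a minimal module that round-trips through a plain state
# dict, mirroring the Reporter.save()/Reporter.load() pattern this PR adopts.
from pathlib import Path

import torch
from torch import nn


class ToyReporter(nn.Module):
    def __init__(self, in_features: int, device=None):
        super().__init__()
        self.in_features = in_features
        self.weight = nn.Parameter(torch.zeros(in_features, device=device))

    def save(self, path: Path | str) -> None:
        # CPU tensors plus the constructor argument, as in the new save() methods below.
        state = {k: v.cpu() for k, v in self.state_dict().items()}
        state.update(in_features=self.in_features)
        torch.save(state, path)

    @classmethod
    def load(cls, path: Path | str, map_location="cpu") -> "ToyReporter":
        state = torch.load(path, map_location=map_location)
        reporter = cls(state.pop("in_features"), device=map_location)
        reporter.load_state_dict(state)
        return reporter


reporter = ToyReporter(16)
reporter.save("toy_reporter.pt")
restored = ToyReporter.load("toy_reporter.pt")  # analogous to Reporter.load(...)
```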
13 changes: 8 additions & 5 deletions elk/run.py
Expand Up @@ -13,6 +13,7 @@
import torch.multiprocessing as mp
import yaml
from simple_parsing.helpers import Serializable, field
from simple_parsing.helpers.serialization import save
from torch import Tensor
from tqdm import tqdm

@@ -37,12 +38,14 @@ class Run(ABC, Serializable):
"""Directory to save results to. If None, a directory will be created
automatically."""

datasets: list[DatasetDictWithName] = field(default_factory=list, init=False)
datasets: list[DatasetDictWithName] = field(
default_factory=list, init=False, to_dict=False
)
"""Datasets containing hidden states and labels for each layer."""

concatenated_layer_offset: int = 0
debug: bool = False
min_gpu_mem: int | None = None
min_gpu_mem: int | None = None # in bytes
num_gpus: int = -1
out_dir: Path | None = None
disable_cache: bool = field(default=False, to_dict=False)
@@ -76,9 +79,9 @@ def execute(
print(f"Output directory at \033[1m{self.out_dir}\033[0m")
self.out_dir.mkdir(parents=True, exist_ok=True)

path = self.out_dir / "cfg.yaml"
with open(path, "w") as f:
self.dump_yaml(f)
# save_dc_types really ought to be the default... We simply can't load
# properly without this flag enabled.
save(self, self.out_dir / "cfg.yaml", save_dc_types=True)

path = self.out_dir / "fingerprints.yaml"
with open(path, "w") as meta_f:
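For context on the cfg.yaml change: `save_dc_types=True` makes simple_parsing record which dataclass subclass was serialized, which is what lets the config be deserialized back into the right `Run` subclass later. Below is a minimal sketch with a toy hierarchy rather than the repo's classes; the `load` round-trip is an assumption about simple_parsing's serialization API, not code from this PR.

```python
# Hypothetical example of save_dc_types; class names are illustrative.
from dataclasses import dataclass

from simple_parsing.helpers import Serializable
from simple_parsing.helpers.serialization import load, save


@dataclass
class BaseConfig(Serializable):
    seed: int = 42


@dataclass
class EigenConfig(BaseConfig):
    num_heads: int = 1


cfg = EigenConfig(num_heads=4)
# With save_dc_types=True the YAML records the concrete subclass, so loading
# against the base class can reconstruct an EigenConfig rather than a BaseConfig.
save(cfg, "cfg.yaml", save_dc_types=True)
restored = load(BaseConfig, "cfg.yaml")
```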
19 changes: 17 additions & 2 deletions elk/training/ccs_reporter.py
Expand Up @@ -3,6 +3,7 @@
import math
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional, cast

import torch
@@ -59,7 +60,6 @@ class CcsReporterConfig(ReporterConfig):
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
num_layers: int = 1
pre_ln: bool = False
seed: int = 42
supervised_weight: float = 0.0

lr: float = 1e-2
@@ -68,6 +68,10 @@
optimizer: Literal["adam", "lbfgs"] = "lbfgs"
weight_decay: float = 0.01

@classmethod
def reporter_class(cls) -> type[Reporter]:
return CcsReporter

def __post_init__(self):
self.loss_dict = parse_loss(self.loss)

@@ -94,6 +98,11 @@ def __init__(
):
super().__init__()
self.config = cfg
self.in_features = in_features

# Learnable Platt scaling parameters
self.bias = nn.Parameter(torch.zeros(1, device=device, dtype=dtype))
self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype))

hidden_size = cfg.hidden_size or 4 * in_features // 3

@@ -239,7 +248,7 @@ def forward(self, x: Tensor) -> Tensor:

def raw_forward(self, x: Tensor) -> Tensor:
"""Apply the probe to the provided input, without normalization."""
return self.probe(x).squeeze(-1)
return self.probe(x).mul(self.scale).add(self.bias).squeeze(-1)

def loss(
self,
@@ -401,3 +410,9 @@ def closure():

optimizer.step(closure)
return float(loss)

def save(self, path: Path | str) -> None:
"""Save the reporter to a file."""
state = {k: v.cpu() for k, v in self.state_dict().items()}
state.update(in_features=self.in_features)
torch.save(state, path)
110 changes: 47 additions & 63 deletions elk/training/eigen_reporter.py
@@ -1,41 +1,47 @@
"""An ELK reporter network."""

from dataclasses import dataclass
from typing import Optional
from pathlib import Path

import torch
from einops import rearrange, repeat
from torch import Tensor, nn, optim
from einops import rearrange
from torch import Tensor, nn

from ..metrics import to_one_hot
from ..truncated_eigh import truncated_eigh
from ..utils.math_util import cov_mean_fused
from .reporter import Reporter, ReporterConfig


@dataclass
class EigenReporterConfig(ReporterConfig):
"""Configuration for an EigenReporter.

Args:
var_weight: The weight of the variance term in the loss.
neg_cov_weight: The weight of the negative covariance term in the loss.
num_heads: The number of reporter heads to fit. In other words, the number
of eigenvectors to compute from the VINC matrix.
"""
"""Configuration for an EigenReporter."""

var_weight: float = 0.0
"""The weight of the variance term in the loss."""

neg_cov_weight: float = 0.5
"""The weight of the negative covariance term in the loss."""

num_heads: int = 1
"""The number of reporter heads to fit."""

save_reporter_stats: bool = False
"""Whether to save the reporter statistics to disk in EigenReporter.save(). This
is useful for debugging and analysis, but can take up a lot of disk space."""

use_centroids: bool = True
"""Whether to average hiddens within each cluster before computing covariance."""

def __post_init__(self):
if not (0 <= self.neg_cov_weight <= 1):
raise ValueError("neg_cov_weight must be in [0, 1]")
if self.num_heads <= 0:
raise ValueError("num_heads must be positive")

@classmethod
def reporter_class(cls) -> type[Reporter]:
return EigenReporter


class EigenReporter(Reporter):
"""A linear reporter whose weights are computed via eigendecomposition.
@@ -71,6 +77,7 @@ class EigenReporter(Reporter):
intercluster_cov_M2: Tensor # variance
intracluster_cov: Tensor # invariance
contrastive_xcov_M2: Tensor # negative covariance

n: Tensor
class_means: Tensor | None
weight: Tensor
@@ -79,40 +86,50 @@ def __init__(
self,
cfg: EigenReporterConfig,
in_features: int,
num_classes: int | None = 2,
num_classes: int | None = None,
*,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.config = cfg
self.in_features = in_features
self.num_classes = num_classes

# Learnable Platt scaling parameters
self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype))
self.scale = nn.Parameter(torch.ones(cfg.num_heads, device=device, dtype=dtype))

# Running statistics
self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long))
self.register_buffer(
"n",
torch.zeros((), device=device, dtype=torch.long),
persistent=cfg.save_reporter_stats,
)
self.register_buffer(
"class_means",
(
torch.zeros(num_classes, in_features, device=device, dtype=dtype)
if num_classes is not None
else None
),
persistent=cfg.save_reporter_stats,
)

self.register_buffer(
"contrastive_xcov_M2",
torch.zeros(in_features, in_features, device=device, dtype=dtype),
persistent=cfg.save_reporter_stats,
)
self.register_buffer(
"intercluster_cov_M2",
torch.zeros(in_features, in_features, device=device, dtype=dtype),
persistent=cfg.save_reporter_stats,
)
self.register_buffer(
"intracluster_cov",
torch.zeros(in_features, in_features, device=device, dtype=dtype),
persistent=cfg.save_reporter_stats,
)

# Reporter weights
@@ -128,10 +145,12 @@ def forward(self, hiddens: Tensor) -> Tensor:

@property
def contrastive_xcov(self) -> Tensor:
assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?"
return self.contrastive_xcov_M2 / self.n

@property
def intercluster_cov(self) -> Tensor:
assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?"
return self.intercluster_cov_M2 / self.n

@property
@@ -140,19 +159,13 @@ def confidence(self) -> Tensor:

@property
def invariance(self) -> Tensor:
assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?"
return -self.weight @ self.intracluster_cov @ self.weight.mT

@property
def consistency(self) -> Tensor:
return -self.weight @ self.contrastive_xcov @ self.weight.mT

def clear(self) -> None:
"""Clear the running statistics of the reporter."""
self.contrastive_xcov_M2.zero_()
self.intracluster_cov.zero_()
self.intercluster_cov_M2.zero_()
self.n.zero_()

@torch.no_grad()
def update(self, hiddens: Tensor) -> None:
(n, _, k, d) = hiddens.shape
@@ -239,55 +252,26 @@ def fit_streaming(self, truncated: bool = False) -> float:
self.weight.data = Q.T
return -float(L[-1])

def fit(
self,
hiddens: Tensor,
labels: Optional[Tensor] = None,
) -> float:
def fit(self, hiddens: Tensor) -> float:
"""Fit the probe to the contrast set `hiddens`.

Args:
hiddens: The contrast set of shape [batch, variants, choices, dim].
labels: The ground truth labels if available.

Returns:
loss: Negative eigenvalue associated with the VINC direction.
"""
self.update(hiddens)
loss = self.fit_streaming()

if labels is not None:
(_, v, k, _) = hiddens.shape
hiddens = rearrange(hiddens, "n v k d -> (n v k) d")
labels = to_one_hot(repeat(labels, "n -> (n v)", v=v), k).flatten()

self.platt_scale(labels, hiddens)

return loss

def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100):
"""Fit the scale and bias terms to data with LBFGS.

Args:
labels: Binary labels of shape [batch].
hiddens: Hidden states of shape [batch, dim].
max_iter: Maximum number of iterations for LBFGS.
"""
opt = optim.LBFGS(
[self.bias, self.scale],
line_search_fn="strong_wolfe",
max_iter=max_iter,
tolerance_change=torch.finfo(hiddens.dtype).eps,
tolerance_grad=torch.finfo(hiddens.dtype).eps,
return self.fit_streaming()

def save(self, path: Path | str) -> None:
"""Save the reporter to a file."""
# We basically never want to instantiate the reporter on the same device
# it happened to be trained on, so we save the state dict as CPU tensors.
# Bizarrely, this also seems to save a LOT of disk space in some cases.
state = {k: v.cpu() for k, v in self.state_dict().items()}
state.update(
in_features=self.in_features,
num_classes=self.num_classes,
)

def closure():
opt.zero_grad()
loss = nn.functional.binary_cross_entropy_with_logits(
self(hiddens), labels.float()
)

loss.backward()
return float(loss)

opt.step(closure)
torch.save(state, path)
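The size reduction in the PR title comes almost entirely from the `persistent=cfg.save_reporter_stats` flags above: buffers registered as non-persistent are excluded from `state_dict()`, so the d × d running-statistic matrices are no longer written to disk unless `save_reporter_stats=True`. A toy illustration of the effect (hypothetical class and numbers, not measurements from this PR):

```python
# Illustrative only: how persistent=False shrinks the serialized state dict.
import os

import torch
from torch import nn


class ToyEigenReporter(nn.Module):
    def __init__(self, in_features: int, num_heads: int = 1, save_stats: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(num_heads, in_features))
        self.scale = nn.Parameter(torch.ones(num_heads))
        self.bias = nn.Parameter(torch.zeros(num_heads))
        # d x d running statistic: needed while fitting, but only serialized
        # when save_stats (cf. save_reporter_stats) is True.
        self.register_buffer(
            "intercluster_cov_M2",
            torch.zeros(in_features, in_features),
            persistent=save_stats,
        )


d = 8192  # hidden size of a large model
for save_stats in (True, False):
    reporter = ToyEigenReporter(d, save_stats=save_stats)
    state = {k: v.cpu() for k, v in reporter.state_dict().items()}
    torch.save(state, "reporter.pt")
    print(f"save_stats={save_stats}: {os.path.getsize('reporter.pt') / 1e6:.1f} MB")
# The d*d float32 matrix alone is ~268 MB at d=8192, while the weight, scale,
# and bias together are a few tens of kilobytes -- a >1000x difference.
```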