diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5389f1b8..4cfbdb86 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,4 +24,4 @@ repos:
     hooks:
       - id: codespell
         # The promptsource templates spuriously get flagged without this
-        args: ["-L fpr", "--skip=*.yaml"]
+        args: ["-L fpr,leace", "--skip=*.yaml"]
diff --git a/elk/__init__.py b/elk/__init__.py
index e3e05a56..ce69da0d 100644
--- a/elk/__init__.py
+++ b/elk/__init__.py
@@ -1,10 +1,10 @@
 from .extraction import Extract, extract_hiddens
-from .training import EigenReporter, EigenReporterConfig
+from .training import EigenFitter, EigenFitterConfig
 from .truncated_eigh import truncated_eigh
 
 __all__ = [
-    "EigenReporter",
-    "EigenReporterConfig",
+    "EigenFitter",
+    "EigenFitterConfig",
     "extract_hiddens",
     "Extract",
     "truncated_eigh",
diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index 034bde14..d6054e33 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -9,7 +9,6 @@
 from ..files import elk_reporter_dir
 from ..metrics import evaluate_preds
 from ..run import Run
-from ..training import Reporter
 from ..utils import Color
 
 
@@ -40,8 +39,7 @@ def apply_to_layer(
         experiment_dir = elk_reporter_dir() / self.source
 
         reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
-        reporter = Reporter.load(reporter_path, map_location=device)
-        reporter.eval()
+        reporter = torch.load(reporter_path, map_location=device)
 
         row_bufs = defaultdict(list)
         for ds_name, (val_h, val_gt, _) in val_output.items():
diff --git a/elk/training/__init__.py b/elk/training/__init__.py
index 31fb63a3..54a47b22 100644
--- a/elk/training/__init__.py
+++ b/elk/training/__init__.py
@@ -1,16 +1,15 @@
-from .ccs_reporter import CcsReporter, CcsReporterConfig
+from .ccs_reporter import CcsConfig, CcsReporter
 from .classifier import Classifier
-from .concept_eraser import ConceptEraser
-from .eigen_reporter import EigenReporter, EigenReporterConfig
-from .reporter import Reporter, ReporterConfig
+from .common import FitterConfig
+from .eigen_reporter import EigenFitter, EigenFitterConfig
+from .platt_scaling import PlattMixin
 
 __all__ = [
     "CcsReporter",
-    "CcsReporterConfig",
+    "CcsConfig",
     "Classifier",
-    "ConceptEraser",
-    "EigenReporter",
-    "EigenReporterConfig",
-    "Reporter",
-    "ReporterConfig",
+    "EigenFitter",
+    "EigenFitterConfig",
+    "FitterConfig",
+    "PlattMixin",
 ]
diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index a7f2121d..cd161dd9 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -3,71 +3,61 @@
 import math
 from copy import deepcopy
 from dataclasses import dataclass, field
-from pathlib import Path
 from typing import Literal, Optional, cast
 
 import torch
 import torch.nn as nn
+from concept_erasure import LeaceFitter
 from torch import Tensor
 
 from ..parsing import parse_loss
 from ..utils.typing import assert_type
-from .concept_eraser import ConceptEraser
+from .common import FitterConfig
 from .losses import LOSSES
-from .reporter import Reporter, ReporterConfig
+from .platt_scaling import PlattMixin
 
 
 @dataclass
-class CcsReporterConfig(ReporterConfig):
-    """
-    Args:
-        activation: The activation function to use. Defaults to GELU.
-        bias: Whether to use a bias term in the linear layers. Defaults to True.
-        hidden_size: The number of hidden units in the MLP. Defaults to None.
-            By default, use an MLP expansion ratio of 4/3. This ratio is used by
-            Tucker et al. (2022) in their 3-layer
-            MLP probes. We could also use a ratio of 4, imitating transformer FFNs,
-            but this seems to lead to excessively large MLPs when num_layers > 2.
-        init: The initialization scheme to use. Defaults to "zero".
-        loss: The loss function to use. list of strings, each of the form
-            "coef*name", where coef is a float and name is one of the keys in
-            `elk.training.losses.LOSSES`.
-            Example: --loss 1.0*consistency_squared 0.5*prompt_var
-            corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var.
-            Defaults to the loss "ccs_squared_loss".
-        normalization: The kind of normalization to apply to the hidden states.
-        num_layers: The number of layers in the MLP. Defaults to 1.
-        pre_ln: Whether to include a LayerNorm module before the first linear
-            layer. Defaults to False.
-        supervised_weight: The weight of the supervised loss. Defaults to 0.0.
-
-        lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`.
-            Defaults to 1e-2.
-        num_epochs: The number of epochs to train for. Defaults to 1000.
-        num_tries: The number of times to try training the reporter. Defaults to 10.
-        optimizer: The optimizer to use. Defaults to "adam".
-        weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01.
-    """
-
+class CcsConfig(FitterConfig):
     activation: Literal["gelu", "relu", "swish"] = "gelu"
+    """The activation function to use."""
     bias: bool = True
+    """Whether to use a bias term in the linear layers."""
     hidden_size: Optional[int] = None
+    """
+    The number of hidden units in the MLP. Defaults to None. By default, use an MLP
+    expansion ratio of 4/3. This ratio is used by Tucker et al. (2022)
+    in their 3-layer MLP probes. We could also use
+    a ratio of 4, imitating transformer FFNs, but this seems to lead to excessively
+    large MLPs when num_layers > 2.
+    """
     init: Literal["default", "pca", "spherical", "zero"] = "default"
+    """The initialization scheme to use."""
     loss: list[str] = field(default_factory=lambda: ["ccs"])
+    """
+    The loss function to use. list of strings, each of the form "coef*name", where coef
+    is a float and name is one of the keys in `elk.training.losses.LOSSES`.
+    Example: `--loss 1.0*consistency_squared 0.5*prompt_var` corresponds to the loss
+    function 1.0*consistency_squared + 0.5*prompt_var.
+    """
     loss_dict: dict[str, float] = field(default_factory=dict, init=False)
     num_layers: int = 1
+    """The number of layers in the MLP."""
     pre_ln: bool = False
+    """Whether to include a LayerNorm module before the first linear layer."""
     supervised_weight: float = 0.0
+    """The weight of the supervised loss."""
 
     lr: float = 1e-2
+    """The learning rate to use. Ignored when `optimizer` is `"lbfgs"`."""
     num_epochs: int = 1000
+    """The number of epochs to train for."""
     num_tries: int = 10
+    """The number of times to try training the reporter."""
     optimizer: Literal["adam", "lbfgs"] = "lbfgs"
+    """The optimizer to use."""
     weight_decay: float = 0.01
-
-    @classmethod
-    def reporter_class(cls) -> type[Reporter]:
-        return CcsReporter
+    """The weight decay or L2 penalty to use."""
 
     def __post_init__(self):
         self.loss_dict = parse_loss(self.loss)
@@ -76,7 +66,7 @@ def __post_init__(self):
         self.loss = [f"{coef}*{name}" for name, coef in self.loss_dict.items()]
 
 
-class CcsReporter(Reporter):
+class CcsReporter(nn.Module, PlattMixin):
     """CCS reporter network.
 
     Args:
@@ -84,11 +74,11 @@ class CcsReporter(Reporter):
         cfg: The reporter configuration.
""" - config: CcsReporterConfig + config: CcsConfig def __init__( self, - cfg: CcsReporterConfig, + cfg: CcsConfig, in_features: int, *, device: str | torch.device | None = None, @@ -106,12 +96,7 @@ def __init__( hidden_size = cfg.hidden_size or 4 * in_features // 3 - self.norm = ConceptEraser( - in_features, - 2 * num_variants, - device=device, - dtype=dtype, - ) + self.norm = None self.probe = nn.Sequential( nn.Linear( in_features, @@ -175,6 +160,8 @@ def reset_parameters(self): def forward(self, x: Tensor) -> Tensor: """Return the credence assigned to the hidden state `x`.""" + assert self.norm is not None, "Must call fit() before forward()" + raw_scores = self.probe(self.norm(x)).squeeze(-1) return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) @@ -203,19 +190,22 @@ def fit(self, hiddens: Tensor) -> float: x_neg, x_pos = hiddens.unbind(2) # One-hot indicators for each prompt template - n, v, _ = x_neg.shape + n, v, d = x_neg.shape prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1) - self.norm.update( + fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device) + fitter.update( x=x_neg, # Independent indicator for each (template, pseudo-label) pair - y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1), + z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1), ) - self.norm.update( + fitter.update( x=x_pos, # Independent indicator for each (template, pseudo-label) pair - y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1), + z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1), ) + self.norm = fitter.eraser + x_neg, x_pos = self.norm(x_neg), self.norm(x_pos) # Record the best acc, loss, and params found so far @@ -299,9 +289,3 @@ def closure(): optimizer.step(closure) return float(loss) - - def save(self, path: Path | str) -> None: - """Save the reporter to a file.""" - state = {k: v.cpu() for k, v in self.state_dict().items()} - state.update(in_features=self.in_features, num_variants=self.num_variants) - torch.save(state, path) diff --git a/elk/training/common.py b/elk/training/common.py new file mode 100644 index 00000000..d93ff006 --- /dev/null +++ b/elk/training/common.py @@ -0,0 +1,31 @@ +"""An ELK reporter network.""" + +from dataclasses import dataclass + +from concept_erasure import LeaceEraser +from simple_parsing.helpers import Serializable +from torch import Tensor, nn + +from .platt_scaling import PlattMixin + + +@dataclass +class FitterConfig(Serializable, decode_into_subclasses=True): + seed: int = 42 + """The random seed to use.""" + + +@dataclass +class Reporter(PlattMixin): + weight: Tensor + eraser: LeaceEraser + + def __post_init__(self): + # Platt scaling parameters + self.bias = nn.Parameter(self.weight.new_zeros(1)) + self.scale = nn.Parameter(self.weight.new_ones(1)) + + def __call__(self, hiddens: Tensor) -> Tensor: + """Return the predicted log odds on input `x`.""" + raw_scores = self.eraser(hiddens) @ self.weight.mT + return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) diff --git a/elk/training/concept_eraser.py b/elk/training/concept_eraser.py deleted file mode 100644 index 3432dd2f..00000000 --- a/elk/training/concept_eraser.py +++ /dev/null @@ -1,113 +0,0 @@ -import torch -from torch import Tensor, nn - - -class ConceptEraser(nn.Module): - """Removes the subspace responsible for correlations between hiddens and labels.""" - - mean_x: Tensor - """Running mean of X.""" - - mean_y: Tensor - """Running mean of Y.""" - - xcov_M2: Tensor - """Unnormalized cross-covariance matrix 
X^T Y.""" - - n: Tensor - """Number of samples seen so far.""" - - def __init__( - self, - x_dim: int, - y_dim: int, - *, - batch_dims: tuple[int, ...] = (), - device: str | torch.device | None = None, - dtype: torch.dtype | None = None, - rank: int | None = None, - ): - super().__init__() - - self.batch_dims = batch_dims - self.y_dim = y_dim - self.x_dim = x_dim - self.rank = rank or y_dim - - self.register_buffer( - "mean_x", torch.zeros(*batch_dims, x_dim, device=device, dtype=dtype) - ) - self.register_buffer("mean_y", self.mean_x.new_zeros(*batch_dims, y_dim)) - self.register_buffer( - "xcov_M2", - self.mean_x.new_zeros(*batch_dims, x_dim, y_dim), - ) - self.register_buffer("n", torch.tensor(0, device=device, dtype=dtype)) - - def forward(self, x: Tensor) -> Tensor: - """Remove the subspace responsible for correlations between x and y.""" - *_, d, _ = self.xcov_M2.shape - assert self.n > 0, "Call update() before forward()" - assert x.shape[-1] == d - - # First center the input - x_ = x - self.mean_x - - # Remove the subspace. We treat x_ as a batch of (1 x d) vectors - proj = (x_[..., None, :] @ self.u) @ self.u.mT - x_ -= proj.squeeze(-2) - - return x_ - - @torch.no_grad() - def update(self, x: Tensor, y: Tensor) -> "ConceptEraser": - """Update the running statistics with a new batch of data.""" - *_, d, c = self.xcov_M2.shape - - # Flatten everything before the batch_dims - x = x.reshape(-1, *self.batch_dims, d).type_as(self.mean_x) - - n, *_, d2 = x.shape - assert d == d2, f"Unexpected number of features {d2}" - - # y might start out 1D, but we want to treat it as 2D - y = y.reshape(n, *self.batch_dims, -1).type_as(x) - assert y.shape[-1] == c, f"Unexpected number of classes {y.shape[-1]}" - - self.n += n - - # Welford's online algorithm - delta_x = x - self.mean_x - self.mean_x += delta_x.sum(dim=0) / self.n - - delta_y = y - self.mean_y - self.mean_y += delta_y.sum(dim=0) / self.n - delta_y2 = y - self.mean_y - - self.xcov_M2 += torch.einsum("b...m,b...n->...mn", delta_x, delta_y2) - return self - - @property - def u(self) -> Tensor: - """Orthonormal basis for the subspace to remove.""" - if self.y_dim == self.rank: - # When we're entirely erasing the subspace, we can use QR instead of SVD to - # get an orthonormal basis for the column space of the xcov matrix - u, _ = torch.linalg.qr(self.xcov) - else: - # We only want to erase the highest energy part of the subspace - u, _, _ = torch.svd_lowrank(self.xcov, q=self.rank) - - return u - - @property - def P(self) -> Tensor: - """Projection matrix for removing the subspace.""" - u = self.u - eye = torch.eye(self.x_dim, device=u.device, dtype=u.dtype) - return eye - u @ u.mT - - @property - def xcov(self) -> Tensor: - """The cross-covariance matrix.""" - return self.xcov_M2 / self.n diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 48e3c1ee..a3525b1d 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -1,21 +1,19 @@ """An ELK reporter network.""" from dataclasses import dataclass -from pathlib import Path import torch +from concept_erasure import LeaceFitter from einops import rearrange -from torch import Tensor, nn +from torch import Tensor -from ..truncated_eigh import truncated_eigh from ..utils.math_util import cov_mean_fused -from .concept_eraser import ConceptEraser -from .reporter import Reporter, ReporterConfig +from .common import FitterConfig, Reporter @dataclass -class EigenReporterConfig(ReporterConfig): - """Configuration for an EigenReporter.""" +class 
EigenFitterConfig(FitterConfig): + """Configuration for an EigenFitter.""" var_weight: float = 0.0 """The weight of the variance term in the loss.""" @@ -27,7 +25,7 @@ class EigenReporterConfig(ReporterConfig): """The number of eigenvectors to compute from the VINC matrix.""" save_reporter_stats: bool = False - """Whether to save the reporter statistics to disk in EigenReporter.save(). This + """Whether to save the reporter statistics to disk in EigenFitter.save(). This is useful for debugging and analysis, but can take up a lot of disk space.""" erase_prompts: bool = False @@ -42,13 +40,9 @@ def __post_init__(self): if self.num_heads <= 0: raise ValueError("num_heads must be positive") - @classmethod - def reporter_class(cls) -> type[Reporter]: - return EigenReporter - -class EigenReporter(Reporter): - """A linear reporter whose weights are computed via eigendecomposition. +class EigenFitter: + """Fit a linear reporter with eigendecomposition. Args: cfg: The reporter configuration. @@ -69,19 +63,19 @@ class EigenReporter(Reporter): the columns are sorted in descending order of eigenvalue magnitude. """ - config: EigenReporterConfig + config: EigenFitterConfig intercluster_cov_M2: Tensor # variance intracluster_cov: Tensor # invariance contrastive_xcov_M2: Tensor # negative covariance n: Tensor - class_means: Tensor | None + class_means: Tensor weight: Tensor def __init__( self, - cfg: EigenReporterConfig, + cfg: EigenFitterConfig, in_features: int, num_classes: int = 2, *, @@ -95,10 +89,7 @@ def __init__( self.num_classes = num_classes self.num_variants = num_variants - # Learnable Platt scaling parameters - self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype)) - self.scale = nn.Parameter(torch.ones(cfg.num_heads, device=device, dtype=dtype)) - self.norm = ConceptEraser( + self.leace = LeaceFitter( in_features, num_classes * num_variants if cfg.erase_prompts else num_classes, device=device, @@ -106,48 +97,20 @@ def __init__( ) # Running statistics - self.register_buffer( - "n", - torch.zeros((), device=device, dtype=torch.long), - persistent=cfg.save_reporter_stats, - ) - self.register_buffer( - "class_means", - ( - torch.zeros(num_classes, in_features, device=device, dtype=dtype) - if num_classes is not None - else None - ), - persistent=cfg.save_reporter_stats, - ) - - self.register_buffer( - "contrastive_xcov_M2", - torch.zeros(in_features, in_features, device=device, dtype=dtype), - persistent=cfg.save_reporter_stats, + self.n = torch.zeros((), device=device, dtype=torch.long) + self.class_means = torch.zeros( + num_classes, in_features, device=device, dtype=dtype ) - self.register_buffer( - "intercluster_cov_M2", - torch.zeros(in_features, in_features, device=device, dtype=dtype), - persistent=cfg.save_reporter_stats, + self.contrastive_xcov_M2 = torch.zeros( + in_features, in_features, device=device, dtype=dtype ) - self.register_buffer( - "intracluster_cov", - torch.zeros(in_features, in_features, device=device, dtype=dtype), - persistent=cfg.save_reporter_stats, + self.intercluster_cov_M2 = torch.zeros( + in_features, in_features, device=device, dtype=dtype ) - - # Reporter weights - self.register_buffer( - "weight", - torch.zeros(cfg.num_heads, in_features, device=device, dtype=dtype), + self.intracluster_cov = torch.zeros( + in_features, in_features, device=device, dtype=dtype ) - def forward(self, hiddens: Tensor) -> Tensor: - """Return the predicted log odds on input `x`.""" - raw_scores = self.norm(hiddens) @ self.weight.mT - return 
raw_scores.mul(self.scale).add(self.bias).squeeze(-1) - @property def contrastive_xcov(self) -> Tensor: assert self.n > 0, "Stats not initialized; did you set save_reporter_stats?" @@ -184,12 +147,12 @@ def update(self, hiddens: Tensor) -> None: if self.config.erase_prompts: # Independent indicator for each (template, pseudo-label) pair indicators = torch.eye(k * v, device=hiddens.device).expand(n, -1, -1) - self.norm.update(x=hiddens, y=indicators) + self.leace.update(x=hiddens, z=indicators) else: # Only use indicators for each pseudo-label indicators = torch.eye(k, device=hiddens.device).expand(n, v, -1, -1) - self.norm.update(x=hiddens, y=indicators) + self.leace.update(x=hiddens, z=indicators) # *** Invariance (intra-cluster) *** # This is just a standard online *mean* update, since we're computing the @@ -208,16 +171,12 @@ def update(self, hiddens: Tensor) -> None: # Iterating over classes for i, h in enumerate(centroids.unbind(1)): - # Update the running means if needed - if self.class_means is not None: - delta = h - self.class_means[i] - self.class_means[i] += delta.sum(dim=0) / self.n + # Update the running means + delta = h - self.class_means[i] + self.class_means[i] += delta.sum(dim=0) / self.n - # Post-mean update deltas are used to update the (co)variance - delta2 = h - self.class_means[i] # [n, d] - else: - delta = h - h.mean(dim=0) - delta2 = delta + # Post-mean update deltas are used to update the (co)variance + delta2 = h - self.class_means[i] # [n, d] # *** Variance (inter-cluster) *** # See code at https://bit.ly/3YC9BhH and "Welford's online algorithm" @@ -237,7 +196,7 @@ def update(self, hiddens: Tensor) -> None: scale = 1 / (k * (k - 1)) self.contrastive_xcov_M2.addmm_(d.mT, d_, alpha=scale) - def fit_streaming(self, truncated: bool = False) -> float: + def fit_streaming(self) -> Reporter: """Fit the probe using the current streaming statistics.""" inv_weight = 1 - self.config.neg_cov_weight A = ( @@ -247,54 +206,32 @@ def fit_streaming(self, truncated: bool = False) -> float: ) # Remove the subspace responsible for pseudolabel correlations - A = self.norm.P @ A @ self.norm.P.mT - - if truncated: - L, Q = truncated_eigh(A, k=self.config.num_heads, seed=self.config.seed) - else: + A = self.leace.eraser.P @ A @ self.leace.eraser.P.mT + try: + L, Q = torch.linalg.eigh(A) + except torch.linalg.LinAlgError: try: - L, Q = torch.linalg.eigh(A) - except torch.linalg.LinAlgError: - try: - L, Q = torch.linalg.eig(A) - L, Q = L.real, Q.real - except torch.linalg.LinAlgError as e: - # Check if the matrix has non-finite values - if not A.isfinite().all(): - raise ValueError( - "Fitting the reporter failed because the VINC matrix has " - "non-finite entries. Usually this means the hidden states " - "themselves had non-finite values." - ) from e - else: - raise e - - L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :] - - self.weight.data = Q.T - return -float(L[-1]) - - def fit(self, hiddens: Tensor) -> float: + L, Q = torch.linalg.eig(A) + L, Q = L.real, Q.real + except torch.linalg.LinAlgError as e: + # Check if the matrix has non-finite values + if not A.isfinite().all(): + raise ValueError( + "Fitting the reporter failed because the VINC matrix has " + "non-finite entries. Usually this means the hidden states " + "themselves had non-finite values." 
+                    ) from e
+                else:
+                    raise e
+
+        L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :]
+        return Reporter(Q.T, self.leace.eraser)
+
+    def fit(self, hiddens: Tensor) -> Reporter:
         """Fit the probe to the contrast set `hiddens`.
 
         Args:
             hiddens: The contrast set of shape [batch, variants, choices, dim].
-
-        Returns:
-            loss: Negative eigenvalue associated with the VINC direction.
         """
         self.update(hiddens)
         return self.fit_streaming()
-
-    def save(self, path: Path | str) -> None:
-        """Save the reporter to a file."""
-        # We basically never want to instantiate the reporter on the same device
-        # it happened to be trained on, so we save the state dict as CPU tensors.
-        # Bizarrely, this also seems to save a LOT of disk space in some cases.
-        state = {k: v.cpu() for k, v in self.state_dict().items()}
-        state.update(
-            in_features=self.in_features,
-            num_classes=self.num_classes,
-            num_variants=self.num_variants,
-        )
-        torch.save(state, path)
diff --git a/elk/training/platt_scaling.py b/elk/training/platt_scaling.py
new file mode 100644
index 00000000..278d8d95
--- /dev/null
+++ b/elk/training/platt_scaling.py
@@ -0,0 +1,43 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+import torch
+from torch import Tensor, nn, optim
+
+
+class PlattMixin(ABC):
+    """Mixin for classifier-like objects that can be Platt scaled."""
+
+    bias: nn.Parameter
+    scale: nn.Parameter
+
+    @abstractmethod
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        ...
+
+    def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100):
+        """Fit the scale and bias terms to data with LBFGS.
+
+        Args:
+            labels: Binary labels of shape [batch].
+            hiddens: Hidden states of shape [batch, dim].
+            max_iter: Maximum number of iterations for LBFGS.
+        """
+        opt = optim.LBFGS(
+            [self.bias, self.scale],
+            line_search_fn="strong_wolfe",
+            max_iter=max_iter,
+            tolerance_change=torch.finfo(hiddens.dtype).eps,
+            tolerance_grad=torch.finfo(hiddens.dtype).eps,
+        )
+
+        def closure():
+            opt.zero_grad()
+            loss = nn.functional.binary_cross_entropy_with_logits(
+                self(hiddens), labels.float()
+            )
+
+            loss.backward()
+            return float(loss)
+
+        opt.step(closure)
diff --git a/elk/training/reporter.py b/elk/training/reporter.py
deleted file mode 100644
index 1372d329..00000000
--- a/elk/training/reporter.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""An ELK reporter network."""
-
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional
-
-import torch
-from simple_parsing.helpers import Serializable
-from simple_parsing.helpers.serialization import load
-from torch import Tensor, nn, optim
-
-
-@dataclass
-class ReporterConfig(ABC, Serializable, decode_into_subclasses=True):
-    """
-    Args:
-        seed: The random seed to use. Defaults to 42.
-    """
-
-    seed: int = 42
-
-    @classmethod
-    @abstractmethod
-    def reporter_class(cls) -> type["Reporter"]:
-        """Get the reporter class associated with this config."""
-
-
-class Reporter(nn.Module, ABC):
-    """An ELK reporter network."""
-
-    # Learned Platt scaling parameters
-    bias: nn.Parameter
-    scale: nn.Parameter
-
-    def reset_parameters(self):
-        """Reset the parameters of the probe."""
-
-    @abstractmethod
-    def fit(
-        self,
-        hiddens: Tensor,
-        labels: Optional[Tensor] = None,
-    ) -> float:
-        ...
-
-    @classmethod
-    def load(cls, path: Path | str, *, map_location: str = "cpu"):
-        """Load a reporter from a file."""
-        obj = torch.load(path, map_location=map_location)
-        if isinstance(obj, Reporter):  # Backwards compatibility
-            return obj
-
-        # Loading a state dict rather than the full object
-        elif isinstance(obj, dict):
-            cls_path = Path(path).parent / "cfg.yaml"
-            cfg = load(ReporterConfig, cls_path)
-
-            # Non-tensor values get passed to the constructor as kwargs
-            kwargs = {}
-            special_keys = {k for k, v in obj.items() if not isinstance(v, Tensor)}
-            for k in special_keys:
-                kwargs[k] = obj.pop(k)
-
-            reporter_cls = cfg.reporter_class()
-            reporter = reporter_cls(cfg, device=map_location, **kwargs)
-            reporter.load_state_dict(obj)
-            return reporter
-        else:
-            raise TypeError(
-                f"Expected a `dict` or `Reporter` object, but got {type(obj)}."
-            )
-
-    def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100):
-        """Fit the scale and bias terms to data with LBFGS.
-
-        Args:
-            labels: Binary labels of shape [batch].
-            hiddens: Hidden states of shape [batch, dim].
-            max_iter: Maximum number of iterations for LBFGS.
-        """
-        opt = optim.LBFGS(
-            [self.bias, self.scale],
-            line_search_fn="strong_wolfe",
-            max_iter=max_iter,
-            tolerance_change=torch.finfo(hiddens.dtype).eps,
-            tolerance_grad=torch.finfo(hiddens.dtype).eps,
-        )
-
-        def closure():
-            opt.zero_grad()
-            loss = nn.functional.binary_cross_entropy_with_logits(
-                self(hiddens), labels.float()
-            )
-
-            loss.backward()
-            return float(loss)
-
-        opt.step(closure)
diff --git a/elk/training/sweep.py b/elk/training/sweep.py
index 60b9ec58..e4aca5a0 100755
--- a/elk/training/sweep.py
+++ b/elk/training/sweep.py
@@ -9,7 +9,7 @@
 from ..extraction import Extract
 from ..files import memorably_named_dir, sweeps_dir
 from ..plotting.visualize import visualize_sweep
-from ..training.eigen_reporter import EigenReporterConfig
+from ..training.eigen_reporter import EigenFitterConfig
 from ..utils import colorize
 from ..utils.constants import BURNS_DATASETS
 from .train import Elicit
@@ -66,9 +66,9 @@ def __post_init__(self, add_pooled: bool):
             raise ValueError("No models specified")
 
         # can only use hparam_step if we're using an eigen net
         if self.hparam_step > 0 and not isinstance(
-            self.run_template.net, EigenReporterConfig
+            self.run_template.net, EigenFitterConfig
         ):
-            raise ValueError("Can only use hparam_step with EigenReporterConfig")
+            raise ValueError("Can only use hparam_step with EigenFitterConfig")
         elif self.hparam_step > 1:
             raise ValueError("hparam_step must be in [0, 1]")
@@ -136,7 +136,7 @@ def execute(self):
             )
             run = replace(self.run_template, data=data, out_dir=out_dir)
             if var_weight is not None and neg_cov_weight is not None:
-                assert isinstance(run.net, EigenReporterConfig)
+                assert isinstance(run.net, EigenFitterConfig)
                 run.net.var_weight = var_weight
                 run.net.neg_cov_weight = neg_cov_weight
 
diff --git a/elk/training/train.py b/elk/training/train.py
index 88325b5b..8392f2d9 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -15,17 +15,17 @@
 from ..run import Run
 from ..training.supervised import train_supervised
 from ..utils.typing import assert_type
-from .ccs_reporter import CcsReporter, CcsReporterConfig
-from .eigen_reporter import EigenReporter, EigenReporterConfig
-from .reporter import ReporterConfig
+from .ccs_reporter import CcsConfig, CcsReporter
+from .common import FitterConfig
+from .eigen_reporter import EigenFitter, EigenFitterConfig
 
 
 @dataclass
 class Elicit(Run):
     """Full specification of a reporter training run."""
 
-    net: ReporterConfig = subgroups(
-        {"ccs": CcsReporterConfig, "eigen": EigenReporterConfig}, default="eigen"
+    net: FitterConfig = subgroups(
+        {"ccs": CcsConfig, "eigen": EigenFitterConfig}, default="eigen"
     )
     """Config for building the reporter network."""
 
@@ -75,7 +75,9 @@ def apply_to_layer(
             raise ValueError("All datasets must have the same number of classes")
 
         reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
-        if isinstance(self.net, CcsReporterConfig):
+        train_loss = None
+
+        if isinstance(self.net, CcsConfig):
             assert len(train_dict) == 1, "CCS only supports single-task training"
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
 
@@ -87,8 +89,8 @@ def apply_to_layer(
                 rearrange(first_train_h, "n v k d -> (n v k) d"),
             )
 
-        elif isinstance(self.net, EigenReporterConfig):
-            reporter = EigenReporter(
+        elif isinstance(self.net, EigenFitterConfig):
+            fitter = EigenFitter(
                 self.net, d, num_classes=k, num_variants=v, device=device
             )
 
@@ -102,9 +104,9 @@ def apply_to_layer(
                 label_list.append(
                     to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten()
                 )
-                reporter.update(train_h)
+                fitter.update(train_h)
 
-            train_loss = reporter.fit_streaming()
+            reporter = fitter.fit_streaming()
             reporter.platt_scale(
                 torch.cat(label_list),
                 torch.cat(hidden_list),
@@ -113,7 +115,7 @@ def apply_to_layer(
             raise ValueError(f"Unknown reporter config type: {type(self.net)}")
 
         # Save reporter checkpoint to disk
-        reporter.save(reporter_dir / f"layer_{layer}.pt")
+        torch.save(reporter, reporter_dir / f"layer_{layer}.pt")
 
         # Fit supervised logistic regression model
         if self.supervised != "none":
diff --git a/pyproject.toml b/pyproject.toml
index 1b4dcc45..0abef5bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,9 @@ license = { text = "MIT License" }
 dependencies = [
     # Allows us to use device_map in from_pretrained. Also needed for 8bit
     "accelerate",
+    # For pseudolabel and prompt normalization. We're picky about the version because
+    # the package isn't guaranteed to be stable yet.
+    "concept-erasure==0.1.0",
     # Added distributed.split_dataset_by_node for IterableDatasets
     "datasets>=2.9.0",
     "einops",
diff --git a/tests/test_concept_eraser.py b/tests/test_concept_eraser.py
deleted file mode 100644
index ec5519c1..00000000
--- a/tests/test_concept_eraser.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import numpy as np
-import pytest
-import torch
-from sklearn.datasets import make_classification
-from sklearn.linear_model import LogisticRegression
-
-from elk.metrics import to_one_hot
-from elk.training import ConceptEraser
-
-
-@pytest.mark.parametrize("batch_dims", [(), (2,), (3, 4)])
-def test_stats(batch_dims: tuple[int, ...]):
-    num_features = 3
-    num_classes = 2
-    batch_size = 10
-    num_batches = 5
-
-    # Initialize the ConceptEraser
-    eraser = ConceptEraser(num_features, num_classes, batch_dims=batch_dims)
-
-    # Generate random data
-    torch.manual_seed(42)
-    x_data = [
-        torch.randn(batch_size, *batch_dims, num_features) for _ in range(num_batches)
-    ]
-    y_data = [
-        torch.randint(0, num_classes, (batch_size, *batch_dims, num_classes))
-        for _ in range(num_batches)
-    ]
-
-    # Compute cross-covariance matrix using batched updates
-    for x, y in zip(x_data, y_data):
-        eraser.update(x, y)
-
-    # Compute the expected cross-covariance matrix using the whole dataset
-    x_all = torch.cat(x_data)
-    y_all = torch.cat(y_data)
-    mean_x = x_all.mean(dim=0)
-    mean_y = y_all.type_as(x_all).mean(dim=0)
-    x_centered = x_all - mean_x
-    y_centered = y_all - mean_y
-    expected_xcov = torch.einsum("b...m,b...n->...mn", x_centered, y_centered)
-    expected_xcov /= batch_size * num_batches
-
-    # Compare the computed cross-covariance matrix with the expected one
-    torch.testing.assert_close(eraser.xcov, expected_xcov)
-
-
-# Both `1` and `2` are binary classification problems, but `1` means the labels are
-# encoded in a 1D one-hot vector, while `2` means the labels are encoded in an
-# n x 2 one-hot matrix.
-@pytest.mark.parametrize("num_classes", [1, 2, 3, 5, 10, 20])
-def test_projection(num_classes: int):
-    n, d = 2048, 128
-    num_distinct = max(num_classes, 2)
-
-    X, Y = make_classification(
-        n_samples=n,
-        n_features=d,
-        n_classes=num_distinct,
-        n_informative=num_distinct,
-        random_state=42,
-    )
-    X_t = torch.from_numpy(X)
-    Y_t = torch.from_numpy(Y)
-    if num_classes > 1:
-        Y_t = to_one_hot(Y_t, num_classes)
-
-    eraser = ConceptEraser(d, num_classes, dtype=torch.float64).update(X_t, Y_t)
-    X_ = eraser(X_t)
-
-    # Heuristic threshold for singular values taken from torch.linalg.pinv
-    eps = max(n, d) * torch.finfo(X_.dtype).eps
-
-    # Check that the rank of the update is num_classes + 1
-    # The +1 comes from subtracting the mean before projection
-    rank = torch.linalg.svdvals(X_t - X_).gt(eps).sum().float()
-    torch.testing.assert_close(rank, torch.tensor(num_classes + 1.0))
-
-    # Compute class means and check that they are equal after the projection
-    class_means_ = [X_.numpy()[Y == c].mean(axis=0) for c in range(num_distinct)]
-    np.testing.assert_almost_equal(class_means_[1:], class_means_[:-1])
-
-    # Sanity check that class means are NOT equal before the projection
-    class_means = [X[Y == c].mean(axis=0) for c in range(num_distinct)]
-    assert not np.allclose(class_means[1:], class_means[:-1])
-
-    # Logistic regression should not be able to learn anything
-    null_lr = LogisticRegression(max_iter=1000, tol=0.0).fit(X_.numpy(), Y)
-    beta = torch.from_numpy(null_lr.coef_)
-    assert beta.norm(p=torch.inf) < eps
-
-    # Sanity check that it DOES learn something before the projection
-    real_lr = LogisticRegression(max_iter=1000).fit(X, Y)
-    beta = torch.from_numpy(real_lr.coef_)
-    assert beta.norm(p=torch.inf) > 0.1
diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py
index be9ebbca..6303cb03 100644
--- a/tests/test_eigen_reporter.py
+++ b/tests/test_eigen_reporter.py
@@ -1,6 +1,6 @@
 import torch
 
-from elk.training import EigenReporter, EigenReporterConfig
+from elk.training import EigenFitter, EigenFitterConfig
 from elk.utils import batch_cov, cov_mean_fused
 
 
@@ -13,8 +13,8 @@ def test_eigen_reporter():
     x1, x2 = x.chunk(2, dim=0)
     x_neg, x_pos = x.unbind(2)
 
-    reporter = EigenReporter(
-        EigenReporterConfig(),
+    reporter = EigenFitter(
+        EigenFitterConfig(),
         hidden_size,
         dtype=torch.float64,
         num_variants=num_clusters,
diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py
index 7cf0e8c9..bac0f398 100644
--- a/tests/test_smoke_elicit.py
+++ b/tests/test_smoke_elicit.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 
 from elk import Extract
-from elk.training import CcsReporterConfig, EigenReporterConfig
+from elk.training import CcsConfig, EigenFitterConfig
 from elk.training.train import Elicit
 
 
@@ -18,7 +18,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path):
         ),
         num_gpus=2,
         min_gpu_mem=min_mem,
-        net=CcsReporterConfig(),
+        net=CcsConfig(),
         out_dir=tmp_path,
     )
     elicit.execute()
@@ -49,7 +49,7 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path):
         ),
         num_gpus=2,
         min_gpu_mem=min_mem,
-        net=EigenReporterConfig(),
+        net=EigenFitterConfig(),
         out_dir=tmp_path,
     )
     elicit.execute()
diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py
index d58db6cd..4efd7112 100644
--- a/tests/test_smoke_eval.py
+++ b/tests/test_smoke_eval.py
@@ -4,7 +4,7 @@
 
 from elk import Extract
 from elk.evaluation import Eval
-from elk.training import CcsReporterConfig, EigenReporterConfig
+from elk.training import CcsConfig, EigenFitterConfig
 from elk.training.train import Elicit
 
 EVAL_EXPECTED_FILES = [
@@ -34,7 +34,7 @@ def setup_elicit(
         ),
         num_gpus=2,
         min_gpu_mem=min_mem,
-        net=CcsReporterConfig() if is_ccs else EigenReporterConfig(),
+        net=CcsConfig() if is_ccs else EigenFitterConfig(),
         out_dir=tmp_path,
     )
     elicit.execute()
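
For reference, a minimal usage sketch of the refactored API (not part of the patch itself). It mirrors the call sites in train.py and evaluate.py above; the synthetic tensors, shapes, and the "layer_0.pt" path are invented placeholders for illustration only.

# Sketch: fit an EigenFitter, Platt-scale the returned Reporter, then round-trip
# it through torch.save / torch.load the way train.py and evaluate.py now do.
import torch

from elk.training import EigenFitter, EigenFitterConfig

n, v, k, d = 128, 4, 2, 64  # examples, prompt variants, classes, hidden size
hiddens = torch.randn(n, v, k, d)  # stand-in for extracted hidden states
labels = torch.randint(0, k, (n,))  # stand-in for ground-truth labels

fitter = EigenFitter(EigenFitterConfig(), d, num_classes=k, num_variants=v)
fitter.update(hiddens)  # accumulate streaming statistics
reporter = fitter.fit_streaming()  # returns a Reporter (weight + LEACE eraser)

# Platt scaling on flattened (example, variant, class) rows, as in train.py.
flat_hiddens = hiddens.reshape(n * v * k, d)
flat_labels = (
    torch.nn.functional.one_hot(labels, k).repeat_interleave(v, dim=0).flatten()
)
reporter.platt_scale(flat_labels, flat_hiddens)

# Reporters are now pickled whole; evaluation loads them with a plain torch.load.
torch.save(reporter, "layer_0.pt")
reporter = torch.load("layer_0.pt", map_location="cpu")
log_odds = reporter(hiddens)  # credences with shape [n, v, k]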