Commit

Merge branch 'main' into not-273-fix-argument-passthrough-for-sweep

derpyplops committed Jul 12, 2023
2 parents 2dbb34f + a88c01a · commit 4b1cc50
Showing 17 changed files with 209 additions and 586 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -15,7 +15,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.275'
rev: 'v0.0.276'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
@@ -24,4 +24,4 @@ repos:
hooks:
- id: codespell
# The promptsource templates spuriously get flagged without this
args: ["-L fpr", "--skip=*.yaml"]
args: ["-L fpr,leace", "--skip=*.yaml"]
6 changes: 3 additions & 3 deletions elk/__init__.py
@@ -1,10 +1,10 @@
from .extraction import Extract, extract_hiddens
from .training import EigenReporter, EigenReporterConfig
from .training import EigenFitter, EigenFitterConfig
from .truncated_eigh import truncated_eigh

__all__ = [
"EigenReporter",
"EigenReporterConfig",
"EigenFitter",
"EigenFitterConfig",
"extract_hiddens",
"Extract",
"truncated_eigh",
4 changes: 1 addition & 3 deletions elk/evaluation/evaluate.py
@@ -9,7 +9,6 @@
from ..files import elk_reporter_dir
from ..metrics import evaluate_preds
from ..run import Run
from ..training import Reporter
from ..utils import Color


@@ -40,8 +39,7 @@ def apply_to_layer(
experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter = Reporter.load(reporter_path, map_location=device)
reporter.eval()
reporter = torch.load(reporter_path, map_location=device)

row_bufs = defaultdict(list)
for ds_name, (val_h, val_gt, _) in val_output.items():
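For orientation, this change replaces the Reporter.load() / reporter.eval() pair with a plain torch.load() of the serialized reporter object. Below is a minimal sketch of the new loading pattern; the experiment directory and layer number are hypothetical stand-ins for values the run normally derives from elk_reporter_dir() and the source experiment name.

```python
import torch
from pathlib import Path

# Hypothetical paths for illustration only; the real run builds them from
# elk_reporter_dir() and the source experiment name.
experiment_dir = Path("elk-reporters") / "my-experiment"
reporter_path = experiment_dir / "reporters" / "layer_12.pt"

# Reporters are now serialized whole with torch.save(), so torch.load()
# returns a ready-to-call reporter object; no Reporter.load() or .eval()
# step is needed before scoring hidden states.
reporter = torch.load(reporter_path, map_location="cpu")
```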
19 changes: 9 additions & 10 deletions elk/training/__init__.py
@@ -1,16 +1,15 @@
from .ccs_reporter import CcsReporter, CcsReporterConfig
from .ccs_reporter import CcsConfig, CcsReporter
from .classifier import Classifier
from .concept_eraser import ConceptEraser
from .eigen_reporter import EigenReporter, EigenReporterConfig
from .reporter import Reporter, ReporterConfig
from .common import FitterConfig
from .eigen_reporter import EigenFitter, EigenFitterConfig
from .platt_scaling import PlattMixin

__all__ = [
"CcsReporter",
"CcsReporterConfig",
"CcsConfig",
"Classifier",
"ConceptEraser",
"EigenReporter",
"EigenReporterConfig",
"Reporter",
"ReporterConfig",
"EigenFitter",
"EigenFitterConfig",
"FitterConfig",
"PlattMixin",
]
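For downstream users, the net effect of this file is a rename of the public training API. A hypothetical import snippet (not part of the diff) showing the names exported before and after:

```python
# Before this commit:
# from elk.training import CcsReporterConfig, EigenReporter, EigenReporterConfig, Reporter

# After this commit:
from elk.training import CcsConfig, CcsReporter, EigenFitter, EigenFitterConfig, FitterConfig
```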
156 changes: 42 additions & 114 deletions elk/training/ccs_reporter.py
@@ -3,73 +3,61 @@
import math
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional, cast

import torch
import torch.nn as nn
from concept_erasure import LeaceFitter
from torch import Tensor

from ..metrics import roc_auc
from ..parsing import parse_loss
from ..utils.typing import assert_type
from .classifier import Classifier
from .concept_eraser import ConceptEraser
from .common import FitterConfig
from .losses import LOSSES
from .reporter import Reporter, ReporterConfig
from .platt_scaling import PlattMixin


@dataclass
class CcsReporterConfig(ReporterConfig):
"""
Args:
activation: The activation function to use. Defaults to GELU.
bias: Whether to use a bias term in the linear layers. Defaults to True.
hidden_size: The number of hidden units in the MLP. Defaults to None.
By default, use an MLP expansion ratio of 4/3. This ratio is used by
Tucker et al. (2022) <https://arxiv.org/abs/2204.09722> in their 3-layer
MLP probes. We could also use a ratio of 4, imitating transformer FFNs,
but this seems to lead to excessively large MLPs when num_layers > 2.
init: The initialization scheme to use. Defaults to "zero".
loss: The loss function to use. list of strings, each of the form
"coef*name", where coef is a float and name is one of the keys in
`elk.training.losses.LOSSES`.
Example: --loss 1.0*consistency_squared 0.5*prompt_var
corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var.
Defaults to the loss "ccs_squared_loss".
normalization: The kind of normalization to apply to the hidden states.
num_layers: The number of layers in the MLP. Defaults to 1.
pre_ln: Whether to include a LayerNorm module before the first linear
layer. Defaults to False.
supervised_weight: The weight of the supervised loss. Defaults to 0.0.
lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`.
Defaults to 1e-2.
num_epochs: The number of epochs to train for. Defaults to 1000.
num_tries: The number of times to try training the reporter. Defaults to 10.
optimizer: The optimizer to use. Defaults to "adam".
weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01.
"""

class CcsConfig(FitterConfig):
activation: Literal["gelu", "relu", "swish"] = "gelu"
"""The activation function to use."""
bias: bool = True
"""Whether to use a bias term in the linear layers."""
hidden_size: Optional[int] = None
"""
The number of hidden units in the MLP. Defaults to None. By default, use an MLP
expansion ratio of 4/3. This ratio is used by Tucker et al. (2022)
<https://arxiv.org/abs/2204.09722> in their 3-layer MLP probes. We could also use
a ratio of 4, imitating transformer FFNs, but this seems to lead to excessively
large MLPs when num_layers > 2.
"""
init: Literal["default", "pca", "spherical", "zero"] = "default"
"""The initialization scheme to use."""
loss: list[str] = field(default_factory=lambda: ["ccs"])
"""
The loss function to use. list of strings, each of the form "coef*name", where coef
is a float and name is one of the keys in `elk.training.losses.LOSSES`.
Example: `--loss 1.0*consistency_squared 0.5*prompt_var` corresponds to the loss
function 1.0*consistency_squared + 0.5*prompt_var.
"""
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
num_layers: int = 1
"""The number of layers in the MLP."""
pre_ln: bool = False
"""Whether to include a LayerNorm module before the first linear layer."""
supervised_weight: float = 0.0
"""The weight of the supervised loss."""

lr: float = 1e-2
"""The learning rate to use. Ignored when `optimizer` is `"lbfgs"`."""
num_epochs: int = 1000
"""The number of epochs to train for."""
num_tries: int = 10
"""The number of times to try training the reporter."""
optimizer: Literal["adam", "lbfgs"] = "lbfgs"
"""The optimizer to use."""
weight_decay: float = 0.01

@classmethod
def reporter_class(cls) -> type[Reporter]:
return CcsReporter
"""The weight decay or L2 penalty to use."""

def __post_init__(self):
self.loss_dict = parse_loss(self.loss)
@@ -78,19 +66,19 @@ def __post_init__(self):
self.loss = [f"{coef}*{name}" for name, coef in self.loss_dict.items()]


class CcsReporter(Reporter):
class CcsReporter(nn.Module, PlattMixin):
"""CCS reporter network.
Args:
in_features: The number of input features.
cfg: The reporter configuration.
"""

config: CcsReporterConfig
config: CcsConfig

def __init__(
self,
cfg: CcsReporterConfig,
cfg: CcsConfig,
in_features: int,
*,
device: str | torch.device | None = None,
@@ -108,12 +96,7 @@ def __init__(

hidden_size = cfg.hidden_size or 4 * in_features // 3

self.norm = ConceptEraser(
in_features,
2 * num_variants,
device=device,
dtype=dtype,
)
self.norm = None
self.probe = nn.Sequential(
nn.Linear(
in_features,
@@ -142,60 +125,6 @@ def __init__(
)
)

@torch.no_grad()
def check_separability(
self,
train_pair: tuple[Tensor, Tensor],
val_pair: tuple[Tensor, Tensor],
) -> float:
"""Measure how linearly separable the pseudo-labels are for a contrast pair.
Args:
train_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the
contrastive representations. Used for training the classifier.
val_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the
contrastive representations. Used for evaluating the classifier.
Returns:
The AUROC of a linear classifier fit on the pseudo-labels.
"""
x0, x1 = map(self.norm, train_pair)
val_x0, val_x1 = map(self.norm, val_pair)

pseudo_clf = Classifier(x0.shape[-1], device=x0.device) # type: ignore
pseudo_train = torch.cat(
[
torch.zeros_like(x0[..., 0]),
torch.ones_like(x1[..., 0]),
]
).flatten()
pseudo_val = torch.cat(
[
torch.zeros_like(val_x0[..., 0]),
torch.ones_like(val_x1[..., 0]),
]
).flatten()

pseudo_clf.fit(
# b v d -> (b v) d
torch.cat([x0, x1]).flatten(0, 1),
pseudo_train,
# Use the same weight decay as the reporter
l2_penalty=self.config.weight_decay,
)
pseudo_preds = pseudo_clf(
# b v d -> (b v) d
torch.cat([val_x0, val_x1]).flatten(0, 1)
).squeeze(-1)

# Edge case where the classifier learns to set its weights to zero
# Technically AUROC is not defined here but we "fill in" the value of 0.5
# since this is the limit as the weights approach zero
if not pseudo_preds.any():
return 0.5
else:
return roc_auc(pseudo_val, pseudo_preds).item()

def reset_parameters(self):
"""Reset the parameters of the probe.
@@ -231,6 +160,8 @@ def reset_parameters(self):

def forward(self, x: Tensor) -> Tensor:
"""Return the credence assigned to the hidden state `x`."""
assert self.norm is not None, "Must call fit() before forward()"

raw_scores = self.probe(self.norm(x)).squeeze(-1)
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)

@@ -259,19 +190,22 @@ def fit(self, hiddens: Tensor) -> float:
x_neg, x_pos = hiddens.unbind(2)

# One-hot indicators for each prompt template
n, v, _ = x_neg.shape
n, v, d = x_neg.shape
prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1)

self.norm.update(
fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
fitter.update(
x=x_neg,
# Independent indicator for each (template, pseudo-label) pair
y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
)
self.norm.update(
fitter.update(
x=x_pos,
# Independent indicator for each (template, pseudo-label) pair
y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
)
self.norm = fitter.eraser

x_neg, x_pos = self.norm(x_neg), self.norm(x_pos)

# Record the best acc, loss, and params found so far
@@ -355,9 +289,3 @@ def closure():

optimizer.step(closure)
return float(loss)

def save(self, path: Path | str) -> None:
"""Save the reporter to a file."""
state = {k: v.cpu() for k, v in self.state_dict().items()}
state.update(in_features=self.in_features, num_variants=self.num_variants)
torch.save(state, path)
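To make the new normalization path easier to follow: instead of keeping a ConceptEraser module around, fit() now builds a LEACE eraser on the fly from both halves of the contrast pair and stores it as self.norm. Below is a minimal, self-contained sketch of that pattern with made-up sizes (n examples, v prompt templates, d hidden dimensions); it mirrors the calls in the diff but is not itself part of the commit.

```python
import torch
from concept_erasure import LeaceFitter

n, v, d = 8, 3, 16  # made-up sizes: examples, prompt templates, hidden dims
x_neg, x_pos = torch.randn(n, v, d), torch.randn(n, v, d)

# One-hot indicator for each prompt template, as in CcsReporter.fit().
prompt_ids = torch.eye(v).expand(n, -1, -1)

# Independent indicator for each (template, pseudo-label) pair.
fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
fitter.update(x=x_neg, z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1))
fitter.update(x=x_pos, z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1))

# The resulting eraser plays the role of the old ConceptEraser-based norm.
eraser = fitter.eraser
x_neg, x_pos = eraser(x_neg), eraser(x_pos)
```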
31 changes: 31 additions & 0 deletions elk/training/common.py
@@ -0,0 +1,31 @@
"""An ELK reporter network."""

from dataclasses import dataclass

from concept_erasure import LeaceEraser
from simple_parsing.helpers import Serializable
from torch import Tensor, nn

from .platt_scaling import PlattMixin


@dataclass
class FitterConfig(Serializable, decode_into_subclasses=True):
seed: int = 42
"""The random seed to use."""


@dataclass
class Reporter(PlattMixin):
weight: Tensor
eraser: LeaceEraser

def __post_init__(self):
# Platt scaling parameters
self.bias = nn.Parameter(self.weight.new_zeros(1))
self.scale = nn.Parameter(self.weight.new_ones(1))

def __call__(self, hiddens: Tensor) -> Tensor:
"""Return the predicted log odds on input `x`."""
raw_scores = self.eraser(hiddens) @ self.weight.mT
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
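The new Reporter in common.py is a thin container: it erases the labeled concept from the hidden states, projects them onto a learned direction, and applies Platt scaling (scale initialized to 1, bias to 0). A small usage sketch under assumed shapes; the random data, the (1, d) weight, and the single-concept eraser fit are illustrative only.

```python
import torch
from concept_erasure import LeaceFitter

from elk.training.common import Reporter

n, d = 32, 16  # made-up sizes: examples, hidden dims
hiddens = torch.randn(n, d)
labels = torch.randint(0, 2, (n, 1)).float()

# Fit an eraser for one binary concept, then pair it with a probe direction.
fitter = LeaceFitter(d, 1, dtype=hiddens.dtype, device=hiddens.device)
fitter.update(x=hiddens, z=labels)
reporter = Reporter(weight=torch.randn(1, d), eraser=fitter.eraser)

# __call__ erases the concept, projects onto `weight`, and applies the
# (initially identity) Platt scale and bias, returning one score per example.
log_odds = reporter(hiddens)  # shape (n,)
```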