Use concept-erasure implementation of LEACE and SAL (#252)
* Use concept-erasure implementation of LEACE and SAL

* Fix parameter name in CCS

* Fix test failures

* Be picky about the concept-erasure version

* Refactor to support concept-erasure v0.1

* Fix test failure

---------

Co-authored-by: Walter Laurito <[email protected]>
norabelrose and lauritowal committed Jul 10, 2023
1 parent 6f975ff commit a88c01a
Showing 17 changed files with 208 additions and 519 deletions.
.pre-commit-config.yaml (2 changes: 1 addition & 1 deletion)

@@ -24,4 +24,4 @@ repos:
   hooks:
     - id: codespell
       # The promptsource templates spuriously get flagged without this
-      args: ["-L fpr", "--skip=*.yaml"]
+      args: ["-L fpr,leace", "--skip=*.yaml"]
elk/__init__.py (6 changes: 3 additions & 3 deletions)

@@ -1,10 +1,10 @@
 from .extraction import Extract, extract_hiddens
-from .training import EigenReporter, EigenReporterConfig
+from .training import EigenFitter, EigenFitterConfig
 from .truncated_eigh import truncated_eigh
 
 __all__ = [
-    "EigenReporter",
-    "EigenReporterConfig",
+    "EigenFitter",
+    "EigenFitterConfig",
     "extract_hiddens",
     "Extract",
     "truncated_eigh",
elk/evaluation/evaluate.py (4 changes: 1 addition & 3 deletions)

@@ -9,7 +9,6 @@
 from ..files import elk_reporter_dir
 from ..metrics import evaluate_preds
 from ..run import Run
-from ..training import Reporter
 from ..utils import Color
 
 
@@ -40,8 +39,7 @@ def apply_to_layer(
         experiment_dir = elk_reporter_dir() / self.source
 
         reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
-        reporter = Reporter.load(reporter_path, map_location=device)
-        reporter.eval()
+        reporter = torch.load(reporter_path, map_location=device)
 
         row_bufs = defaultdict(list)
         for ds_name, (val_h, val_gt, _) in val_output.items():
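Note the simplification: with the Reporter base class gone, reporters are pickled and restored as ordinary Python objects. A minimal sketch of that round trip (the nn.Linear is a stand-in for a fitted reporter; the path is illustrative):

    import torch
    from torch import nn

    # Stand-in for a fitted reporter: any callable module pickles the same way.
    reporter = nn.Linear(512, 1)

    # Training side: save the object wholesale rather than a state dict.
    torch.save(reporter, "layer_12.pt")

    # Evaluation side: torch.load unpickles it; map_location remaps any
    # CUDA tensors stored inside the object onto the requested device.
    reporter = torch.load("layer_12.pt", map_location="cpu")
    print(reporter(torch.randn(8, 512)).shape)  # torch.Size([8, 1])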
elk/training/__init__.py (19 changes: 9 additions & 10 deletions)

@@ -1,16 +1,15 @@
-from .ccs_reporter import CcsReporter, CcsReporterConfig
+from .ccs_reporter import CcsConfig, CcsReporter
 from .classifier import Classifier
-from .concept_eraser import ConceptEraser
-from .eigen_reporter import EigenReporter, EigenReporterConfig
-from .reporter import Reporter, ReporterConfig
+from .common import FitterConfig
+from .eigen_reporter import EigenFitter, EigenFitterConfig
+from .platt_scaling import PlattMixin
 
 __all__ = [
     "CcsReporter",
-    "CcsReporterConfig",
+    "CcsConfig",
     "Classifier",
-    "ConceptEraser",
-    "EigenReporter",
-    "EigenReporterConfig",
-    "Reporter",
-    "ReporterConfig",
+    "EigenFitter",
+    "EigenFitterConfig",
+    "FitterConfig",
+    "PlattMixin",
 ]
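The net effect on the public training API: EigenReporter/EigenReporterConfig become EigenFitter/EigenFitterConfig, CcsReporterConfig becomes CcsConfig, and the ConceptEraser, Reporter, and ReporterConfig exports are gone. A quick sketch of the new import surface (names taken from the __all__ above; the CcsConfig fields are defined in elk/training/ccs_reporter.py below):

    from elk.training import (
        CcsConfig,
        CcsReporter,
        Classifier,
        EigenFitter,
        EigenFitterConfig,
        FitterConfig,
        PlattMixin,
    )

    # CcsConfig subclasses FitterConfig, so it inherits the seed field.
    cfg = CcsConfig(loss=["ccs"], optimizer="lbfgs")
    print(cfg.seed)  # 42 by default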
elk/training/ccs_reporter.py (100 changes: 42 additions & 58 deletions)

@@ -3,71 +3,61 @@
 import math
 from copy import deepcopy
 from dataclasses import dataclass, field
-from pathlib import Path
 from typing import Literal, Optional, cast
 
 import torch
 import torch.nn as nn
+from concept_erasure import LeaceFitter
 from torch import Tensor
 
 from ..parsing import parse_loss
 from ..utils.typing import assert_type
-from .concept_eraser import ConceptEraser
+from .common import FitterConfig
 from .losses import LOSSES
-from .reporter import Reporter, ReporterConfig
+from .platt_scaling import PlattMixin
 
 
 @dataclass
-class CcsReporterConfig(ReporterConfig):
-    """
-    Args:
-        activation: The activation function to use. Defaults to GELU.
-        bias: Whether to use a bias term in the linear layers. Defaults to True.
-        hidden_size: The number of hidden units in the MLP. Defaults to None.
-            By default, use an MLP expansion ratio of 4/3. This ratio is used by
-            Tucker et al. (2022) <https://arxiv.org/abs/2204.09722> in their 3-layer
-            MLP probes. We could also use a ratio of 4, imitating transformer FFNs,
-            but this seems to lead to excessively large MLPs when num_layers > 2.
-        init: The initialization scheme to use. Defaults to "zero".
-        loss: The loss function to use. list of strings, each of the form
-            "coef*name", where coef is a float and name is one of the keys in
-            `elk.training.losses.LOSSES`.
-            Example: --loss 1.0*consistency_squared 0.5*prompt_var
-            corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var.
-            Defaults to the loss "ccs_squared_loss".
-        normalization: The kind of normalization to apply to the hidden states.
-        num_layers: The number of layers in the MLP. Defaults to 1.
-        pre_ln: Whether to include a LayerNorm module before the first linear
-            layer. Defaults to False.
-        supervised_weight: The weight of the supervised loss. Defaults to 0.0.
-        lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`.
-            Defaults to 1e-2.
-        num_epochs: The number of epochs to train for. Defaults to 1000.
-        num_tries: The number of times to try training the reporter. Defaults to 10.
-        optimizer: The optimizer to use. Defaults to "adam".
-        weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01.
-    """
-
+class CcsConfig(FitterConfig):
     activation: Literal["gelu", "relu", "swish"] = "gelu"
+    """The activation function to use."""
     bias: bool = True
+    """Whether to use a bias term in the linear layers."""
     hidden_size: Optional[int] = None
+    """
+    The number of hidden units in the MLP. Defaults to None. By default, use an MLP
+    expansion ratio of 4/3. This ratio is used by Tucker et al. (2022)
+    <https://arxiv.org/abs/2204.09722> in their 3-layer MLP probes. We could also use
+    a ratio of 4, imitating transformer FFNs, but this seems to lead to excessively
+    large MLPs when num_layers > 2.
+    """
     init: Literal["default", "pca", "spherical", "zero"] = "default"
+    """The initialization scheme to use."""
     loss: list[str] = field(default_factory=lambda: ["ccs"])
+    """
+    The loss function to use. list of strings, each of the form "coef*name", where coef
+    is a float and name is one of the keys in `elk.training.losses.LOSSES`.
+    Example: `--loss 1.0*consistency_squared 0.5*prompt_var` corresponds to the loss
+    function 1.0*consistency_squared + 0.5*prompt_var.
+    """
     loss_dict: dict[str, float] = field(default_factory=dict, init=False)
     num_layers: int = 1
+    """The number of layers in the MLP."""
     pre_ln: bool = False
+    """Whether to include a LayerNorm module before the first linear layer."""
     supervised_weight: float = 0.0
+    """The weight of the supervised loss."""
 
     lr: float = 1e-2
+    """The learning rate to use. Ignored when `optimizer` is `"lbfgs"`."""
     num_epochs: int = 1000
+    """The number of epochs to train for."""
     num_tries: int = 10
+    """The number of times to try training the reporter."""
     optimizer: Literal["adam", "lbfgs"] = "lbfgs"
+    """The optimizer to use."""
     weight_decay: float = 0.01
-
-    @classmethod
-    def reporter_class(cls) -> type[Reporter]:
-        return CcsReporter
+    """The weight decay or L2 penalty to use."""
 
     def __post_init__(self):
         self.loss_dict = parse_loss(self.loss)

@@ -76,19 +66,19 @@ def __post_init__(self):
         self.loss = [f"{coef}*{name}" for name, coef in self.loss_dict.items()]
 
 
-class CcsReporter(Reporter):
+class CcsReporter(nn.Module, PlattMixin):
     """CCS reporter network.
 
     Args:
         in_features: The number of input features.
        cfg: The reporter configuration.
     """
 
-    config: CcsReporterConfig
+    config: CcsConfig
 
     def __init__(
         self,
-        cfg: CcsReporterConfig,
+        cfg: CcsConfig,
         in_features: int,
         *,
         device: str | torch.device | None = None,

@@ -106,12 +96,7 @@ def __init__(
 
         hidden_size = cfg.hidden_size or 4 * in_features // 3
 
-        self.norm = ConceptEraser(
-            in_features,
-            2 * num_variants,
-            device=device,
-            dtype=dtype,
-        )
+        self.norm = None
        self.probe = nn.Sequential(
             nn.Linear(
                 in_features,

@@ -175,6 +160,8 @@ def reset_parameters(self):
 
     def forward(self, x: Tensor) -> Tensor:
         """Return the credence assigned to the hidden state `x`."""
+        assert self.norm is not None, "Must call fit() before forward()"
+
         raw_scores = self.probe(self.norm(x)).squeeze(-1)
         return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
 
@@ -203,19 +190,22 @@ def fit(self, hiddens: Tensor) -> float:
         x_neg, x_pos = hiddens.unbind(2)
 
         # One-hot indicators for each prompt template
-        n, v, _ = x_neg.shape
+        n, v, d = x_neg.shape
         prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1)
 
-        self.norm.update(
+        fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
+        fitter.update(
             x=x_neg,
             # Independent indicator for each (template, pseudo-label) pair
-            y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
+            z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
         )
-        self.norm.update(
+        fitter.update(
             x=x_pos,
             # Independent indicator for each (template, pseudo-label) pair
-            y=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
+            z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
         )
+        self.norm = fitter.eraser
+
         x_neg, x_pos = self.norm(x_neg), self.norm(x_pos)
 
         # Record the best acc, loss, and params found so far

@@ -299,9 +289,3 @@ def closure():
 
         optimizer.step(closure)
         return float(loss)
-
-    def save(self, path: Path | str) -> None:
-        """Save the reporter to a file."""
-        state = {k: v.cpu() for k, v in self.state_dict().items()}
-        state.update(in_features=self.in_features, num_variants=self.num_variants)
-        torch.save(state, path)
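The fit() hunk above is the heart of the change: instead of updating a hand-rolled ConceptEraser, CCS now accumulates statistics in a concept-erasure LeaceFitter and keeps only the resulting eraser. A minimal sketch of that API as exercised by this diff, with illustrative shapes (n contrast pairs, v templates, d hidden dims):

    import torch
    from concept_erasure import LeaceFitter

    n, v, d = 128, 8, 64
    x_neg, x_pos = torch.randn(n, v, d), torch.randn(n, v, d)

    # One-hot indicator for each (template, pseudo-label) pair, as in fit()
    prompt_ids = torch.eye(v).expand(n, -1, -1)

    fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
    fitter.update(x=x_neg, z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1))
    fitter.update(x=x_pos, z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1))

    # .eraser materializes a LeaceEraser: an affine map that removes the
    # linearly available concept while minimally perturbing the inputs.
    eraser = fitter.eraser
    x_neg, x_pos = eraser(x_neg), eraser(x_pos)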
elk/training/common.py (31 changes: 31 additions & 0 deletions)

@@ -0,0 +1,31 @@
+"""An ELK reporter network."""
+
+from dataclasses import dataclass
+
+from concept_erasure import LeaceEraser
+from simple_parsing.helpers import Serializable
+from torch import Tensor, nn
+
+from .platt_scaling import PlattMixin
+
+
+@dataclass
+class FitterConfig(Serializable, decode_into_subclasses=True):
+    seed: int = 42
+    """The random seed to use."""
+
+
+@dataclass
+class Reporter(PlattMixin):
+    weight: Tensor
+    eraser: LeaceEraser
+
+    def __post_init__(self):
+        # Platt scaling parameters
+        self.bias = nn.Parameter(self.weight.new_zeros(1))
+        self.scale = nn.Parameter(self.weight.new_ones(1))
+
+    def __call__(self, hiddens: Tensor) -> Tensor:
+        """Return the predicted log odds on input `x`."""
+        raw_scores = self.eraser(hiddens) @ self.weight.mT
+        return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
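The new Reporter is just a probe weight paired with a LEACE eraser, with the Platt-scaling scale and bias attached in __post_init__. A hedged sketch of assembling and applying one (the random weight and labels are stand-ins for a fitted probe):

    import torch
    from concept_erasure import LeaceFitter
    from elk.training.common import Reporter

    d = 64
    x = torch.randn(256, d)                    # hidden states (illustrative)
    z = torch.randint(0, 2, (256, 1)).float()  # binary concept labels (illustrative)

    fitter = LeaceFitter(d, 1, dtype=x.dtype)
    fitter.update(x=x, z=z)

    # weight has shape (1, d) so eraser(x) @ weight.mT gives (256, 1) scores,
    # and the squeeze(-1) in __call__ yields one log odds value per row.
    reporter = Reporter(weight=torch.randn(1, d), eraser=fitter.eraser)
    log_odds = reporter(x)  # erase, project onto the weight, Platt-scale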
elk/training/concept_eraser.py (113 changes: 0 additions & 113 deletions)

This file was deleted.

(Diffs for the remaining 10 changed files were not loaded.)
