
Commit

Prompt var loss (EleutherAI#107)
* fix bug when training on more than one variant

* cleanup

* add prompt variance loss; separate max_examples args for splits

* prompt_var loss

* rename math_util

* fix pseudo-label training and LR validation for >1 variant

---------

Co-authored-by: Nora Belrose <[email protected]>
Co-authored-by: Alex Mallen <alexm@jessica-a40-0.jessica-a40.tenant-eleutherai.svc.tenant.chi.local>
3 people committed Mar 7, 2023
1 parent 51a674a commit e493136
Showing 4 changed files with 49 additions and 9 deletions.
4 changes: 2 additions & 2 deletions elk/extraction/prompt_dataset.py
@@ -35,8 +35,8 @@ class PromptConfig(Serializable):
label_column: The column containing the labels. By default, we infer this from
the datatypes of the columns in the dataset; if there is only one column
with a `ClassLabel` datatype, we use that.
-max_examples: The maximum number of examples to use from the dataset. If zero,
-use all examples. Defaults to 0.
+max_examples: The maximum number of examples to use from each split of
+the dataset. If zero, use all examples. Defaults to 0.
num_shots: The number of examples to use in few-shot prompts. If zero, prompts
are zero-shot. Defaults to 0.
seed: The seed to use for prompt randomization. Defaults to 42.
34 changes: 33 additions & 1 deletion elk/training/losses.py
@@ -3,6 +3,7 @@
from torch import Tensor
import math
import torch
+import warnings


def H(p: Tensor) -> Tensor:
@@ -39,8 +40,39 @@ def ccs_squared_loss(logit0: Tensor, logit1: Tensor) -> Tensor:
        The sum of the consistency and confidence losses.
    """
    p0, p1 = logit0.sigmoid(), logit1.sigmoid()

    consistency = p0.sub(1 - p1).square().mean()
    confidence = torch.min(p0, p1).square().mean()

    return consistency + confidence


def prompt_var_loss(logit0: Tensor, logit1: Tensor) -> Tensor:
    """
    The prompt-variance CCS loss.
    This is the original CCS loss with an additional term: the squared
    difference between the probability of a proposition and the mean probability
    over all variants of that proposition (templates).
    The loss is symmetric, so it doesn't matter which argument is the original and
    which is the negated proposition.
    Args:
        logit0: The log odds for the original proposition. shape ([batch,] n_variants)
        logit1: The log odds for the negated proposition. shape ([batch,] n_variants)
    Returns:
        The sum of the negation consistency, confidence, and prompt invariance losses.
    """
    assert logit0.shape == logit1.shape
    assert len(logit0.shape) in [1, 2]
    if logit0.shape[-1] == 1:
        warnings.warn(
            "Only one variant provided. Prompt variance loss will equal CCS loss."
        )
    p0, p1 = logit0.sigmoid(), logit1.sigmoid()
    consistency = p0.sub(1 - p1).square().mean()
    confidence = torch.min(p0, p1).square().mean()
    mean_p0, mean_p1 = p0.mean(dim=-1, keepdim=True), p1.mean(dim=-1, keepdim=True)
    prompt_variance = (mean_p0 - p0).square().mean() + (mean_p1 - p1).square().mean()
    return consistency + confidence + prompt_variance
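
For reference, a minimal usage sketch of the new loss (not part of this commit). It assumes the losses are importable from elk/training/losses.py as shown above; with a single variant the prompt-variance term vanishes and the value matches ccs_squared_loss, which is what the warning guards against.

# Standalone sketch, not part of the diff above.
import torch
from elk.training.losses import ccs_squared_loss, prompt_var_loss

logit0 = torch.randn(8, 4)  # (batch=8, n_variants=4) log odds for each proposition
logit1 = torch.randn(8, 4)  # log odds for the negated propositions
loss = prompt_var_loss(logit0, logit1)  # consistency + confidence + prompt variance

# With n_variants == 1 the prompt-variance term is identically zero, so the
# value collapses to the plain squared CCS loss (and the warning above fires).
single0, single1 = torch.randn(8, 1), torch.randn(8, 1)
assert torch.isclose(prompt_var_loss(single0, single1), ccs_squared_loss(single0, single1))
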
6 changes: 3 additions & 3 deletions elk/training/reporter.py
@@ -1,6 +1,6 @@
"""An ELK reporter network."""

-from .losses import ccs_squared_loss, js_loss
+from .losses import ccs_squared_loss, js_loss, prompt_var_loss
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
@@ -32,7 +32,6 @@ class ReporterConfig(Serializable):
Args:
activation: The activation function to use. Defaults to GELU.
bias: Whether to use a bias term in the linear layers. Defaults to True.
-device: The device to use. Defaults to None, which means "current device".
hidden_size: The number of hidden units in the MLP. Defaults to None.
By default, use an MLP expansion ratio of 4/3. This ratio is used by
Tucker et al. (2022) <https://arxiv.org/abs/2204.09722> in their 3-layer
@@ -50,7 +49,7 @@ class ReporterConfig(Serializable):
bias: bool = True
hidden_size: Optional[int] = None
init: Literal["default", "spherical", "zero"] = "default"
-loss: Literal["js", "squared"] = "squared"
+loss: Literal["js", "squared", "prompt_var"] = "squared"
num_layers: int = 1
pre_ln: bool = False
seed: int = 42
@@ -124,6 +123,7 @@ def __init__(
self.unsupervised_loss = {
"js": js_loss,
"squared": ccs_squared_loss,
"prompt_var": prompt_var_loss,
}[cfg.loss]
self.supervised_weight = cfg.supervised_weight

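A hypothetical usage sketch of the new config option (not part of this commit); it assumes the remaining ReporterConfig fields keep the defaults shown in the diff above.

# Hypothetical sketch; only fields visible in this diff are used.
from elk.training.reporter import ReporterConfig

cfg = ReporterConfig(loss="prompt_var")  # instead of the default "squared"
# Reporter.__init__ then maps the string to a loss function:
#   {"js": js_loss, "squared": ccs_squared_loss, "prompt_var": prompt_var_loss}[cfg.loss]
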
14 changes: 11 additions & 3 deletions elk/training/train.py
@@ -8,6 +8,7 @@
from .reporter import OptimConfig, Reporter, ReporterConfig
from dataclasses import dataclass
from datasets import DatasetDict
+from einops import rearrange
from functools import partial
from pathlib import Path
from simple_parsing import Serializable
@@ -90,10 +91,17 @@ def train_reporter(
torch.zeros_like(train_labels),
torch.ones_like(train_labels),
]
+).repeat_interleave(
+    x0.shape[1]
+)  # make num_variants copies of each pseudo-label

+pseudo_clf.fit(
+    rearrange(torch.cat([x0, x1]), "b v d -> (b v) d"), pseudo_labels
+)
-pseudo_clf.fit(torch.cat([x0, x1]).squeeze(1), pseudo_labels)
with torch.no_grad():
-pseudo_preds = pseudo_clf(torch.cat([val_x0, val_x1]).squeeze(1))
+pseudo_preds = pseudo_clf(
+    rearrange(torch.cat([val_x0, val_x1]), "b v d -> (b v) d")
+)
pseudo_auroc = roc_auc_score(pseudo_labels.cpu(), pseudo_preds.cpu())
if pseudo_auroc > 0.6:
warnings.warn(
@@ -138,7 +146,7 @@ def train_reporter(
lr_model = Classifier(d, device=device)
lr_model.fit(X.view(-1, d), train_labels_aug)

-X_val = torch.cat([val_x0, val_x1]).squeeze()
+X_val = torch.cat([val_x0, val_x1]).view(-1, d)
with torch.no_grad():
lr_preds = lr_model(X_val).sigmoid().cpu()

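To make the reshaping concrete, here is a standalone sketch (not the repository's code; names and sizes are illustrative) of the shape bookkeeping the fix introduces for more than one prompt variant: activations are (batch, n_variants, d), each pseudo-label is repeated once per variant, and rearrange flattens the batch and variant axes together.

# Standalone sketch of the >1-variant shape handling; tensor sizes are made up.
import torch
from einops import rearrange

batch, n_variants, d = 4, 3, 8
x0 = torch.randn(batch, n_variants, d)  # activations for the negated statements
x1 = torch.randn(batch, n_variants, d)  # activations for the original statements
train_labels = torch.randint(0, 2, (batch,))

# One pseudo-label per (example, variant) pair: zeros for x0, ones for x1.
pseudo_labels = torch.cat(
    [torch.zeros_like(train_labels), torch.ones_like(train_labels)]
).repeat_interleave(n_variants)  # shape: (2 * batch * n_variants,)

# Flatten batch and variant axes so each variant becomes its own training row.
X = rearrange(torch.cat([x0, x1]), "b v d -> (b v) d")  # (2 * batch * n_variants, d)
assert X.shape[0] == pseudo_labels.shape[0]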
