Support multiple choice datasets #179

Merged Apr 16, 2023 (54 commits). Changes shown are from 1 commit.

Commits
d292c7c
LM output evaluation for autoregressive models
norabelrose Apr 4, 2023
7ed5ccd
move to own baseline file
lauritowal Apr 4, 2023
ba1d3b2
cleanup
lauritowal Apr 4, 2023
a20d4ca
Support encoder-decoder model LM output
norabelrose Apr 5, 2023
088758e
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 5, 2023
77d7418
isort
norabelrose Apr 5, 2023
5bf63f4
Bug fixes
norabelrose Apr 5, 2023
819cfed
Merge branch 'main' into lm-output
norabelrose Apr 5, 2023
d3d9a8d
Merge branch 'main' into lm-output
norabelrose Apr 5, 2023
b89e23c
Remove test_log_csv_elements
norabelrose Apr 5, 2023
9aef842
Remove Python 3.9 support
norabelrose Apr 5, 2023
0851d4f
Add Pandas to pyproject.toml
norabelrose Apr 5, 2023
207a375
add code (contains still same device cuda error)
lauritowal Apr 5, 2023
e7efcce
fix multiple cuda error, save evals to right folder + cleanup
lauritowal Apr 7, 2023
b5fa54c
Merge branch 'main' into eval_lr
lauritowal Apr 7, 2023
4f8bdc5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
9ca72ba
Fix bug noticed by Waree
norabelrose Apr 7, 2023
d7e4893
Merge remote-tracking branch 'origin/eval_lr' into lm-output
norabelrose Apr 7, 2023
bcdca8a
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 7, 2023
713a251
Add sanity check to load_prompts and refactor binarize
norabelrose Apr 7, 2023
0c35bc7
Changing a ton of stuff
norabelrose Apr 8, 2023
f6a762a
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 10, 2023
f547744
Revert changes to binarize
norabelrose Apr 10, 2023
ab1909f
Stupid prompt_counter bug
norabelrose Apr 10, 2023
f58290f
Merge remote-tracking branch 'origin/main' into lm-output
norabelrose Apr 10, 2023
f912ee6
Remove stupid second set_start_method call
norabelrose Apr 10, 2023
606dcad
Merge remote-tracking branch 'origin/lm-output' into multiclass
norabelrose Apr 10, 2023
0038792
Merge remote-tracking branch 'origin/main' into multiclass
norabelrose Apr 10, 2023
83b480b
Fix bugs in binary case
norabelrose Apr 11, 2023
3e66262
Various little refactors
norabelrose Apr 11, 2023
a8c21a6
Remove .predict and .predict_prob on Reporter; trying to get SciQ to …
norabelrose Apr 11, 2023
5f478b1
Bugfix for Reporter.score on binary tasks
norabelrose Apr 11, 2023
97b26ac
Fix bug where cached hidden states aren’t used when num_gpus is diffe…
norabelrose Apr 12, 2023
11fda87
Actually works now
norabelrose Apr 12, 2023
da4c72f
Refactor handling of multiple datasets
norabelrose Apr 13, 2023
e1675f7
Various fixes
norabelrose Apr 13, 2023
8cc325b
Merge remote-tracking branch 'origin/main' into multi-ds-eval
norabelrose Apr 13, 2023
14987e1
Fix math tests
norabelrose Apr 13, 2023
88683fa
Fix smoke tests
norabelrose Apr 13, 2023
a6c382e
All tests working ostensibly
norabelrose Apr 13, 2023
ecc53cb
Make CCS normalization customizable
norabelrose Apr 13, 2023
18c7f4c
log each dataset individually
AlexTMallen Apr 13, 2023
94a900c
Merge branch 'multi-ds-eval' into multiclass
norabelrose Apr 13, 2023
5173649
Fix label_column bug
norabelrose Apr 13, 2023
3e6c39c
GLUE MNLI works on Deberta
norabelrose Apr 14, 2023
1e9ce06
Move pseudo AUROC stuff to CcsReporter
norabelrose Apr 14, 2023
35a8f34
Make 'datasets' and 'label_columns' config options more opinionated
norabelrose Apr 14, 2023
615bbb1
tiny spacing change
norabelrose Apr 14, 2023
f021404
Allow for toggling CV
norabelrose Apr 14, 2023
f6629ec
Merge branch 'multi-ds-eval' into multiclass
norabelrose Apr 14, 2023
99f01c3
Remove duplicate dbpedia template
norabelrose Apr 14, 2023
f415f8d
Merge branch 'main' into multiclass
norabelrose Apr 14, 2023
d16c96b
Training on datasets with different numbers of classes now works
norabelrose Apr 15, 2023
6e205a7
Update extraction.py
lauritowal Apr 16, 2023
Training on datasets with different numbers of classes now works
norabelrose committed Apr 15, 2023
commit d16c96b943b72b1f33e74f255518690f2f4e5d30
2 changes: 2 additions & 0 deletions elk/training/__init__.py
@@ -1,11 +1,13 @@
from .ccs_reporter import CcsReporter, CcsReporterConfig
from .classifier import Classifier
from .eigen_reporter import EigenReporter, EigenReporterConfig
from .normalizer import Normalizer
from .reporter import OptimConfig, Reporter, ReporterConfig

__all__ = [
"CcsReporter",
"CcsReporterConfig",
"Classifier",
"EigenReporter",
"EigenReporterConfig",
"Normalizer",
57 changes: 36 additions & 21 deletions elk/training/eigen_reporter.py
@@ -40,20 +40,27 @@ class EigenReporter(Reporter):
"""A linear reporter whose weights are computed via eigendecomposition.

Args:
in_features: The number of input features.
cfg: The reporter configuration.
in_features: The number of input features.
num_classes: The number of classes for tracking the running means. If `None`,
we don't track the running means at all, and the semantics of `update()`
are a bit different. In particular, each call to `update()` is treated as a
new dataset, with a potentially different number of classes. The covariance
matrices are simply averaged over each batch of data passed to `update()`,
instead of being updated with Welford's algorithm. This is useful for
training a single reporter on multiple datasets, where the number of
classes may vary.

Attributes:
config: The reporter configuration.
intercluster_cov_M2: The running sum of the covariance matrices of the
centroids of the positive and negative clusters.
intercluster_cov_M2: The unnormalized covariance matrix averaged over all
classes.
intracluster_cov: The running mean of the covariance matrices within each
cluster. This doesn't need to be a running sum because it doesn't use
Welford's algorithm.
contrastive_xcov_M2: The running sum of the cross-covariance between the
centroids of the positive and negative clusters.
n: The running sum of the number of samples in the positive and negative
clusters.
contrastive_xcov_M2: Average of the unnormalized cross-covariance matrices
across all pairs of classes (k, k').
n: The running sum of the number of clusters processed by `update()`.
weight: The reporter weight matrix. Guaranteed to always be orthogonal, and
the columns are sorted in descending order of eigenvalue magnitude.
"""
@@ -64,16 +71,17 @@ class EigenReporter(Reporter):
intracluster_cov: Tensor # invariance
contrastive_xcov_M2: Tensor # negative covariance
n: Tensor
class_means: Tensor
class_means: Tensor | None
weight: Tensor

def __init__(
self,
cfg: EigenReporterConfig,
in_features: int,
num_classes: int = 2,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
num_classes: int | None = 2,
*,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.config = cfg
@@ -86,7 +94,11 @@ def __init__(
self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long))
self.register_buffer(
"class_means",
torch.zeros(num_classes, in_features, device=device, dtype=dtype),
(
torch.zeros(num_classes, in_features, device=device, dtype=dtype)
if num_classes is not None
else None
),
)

self.register_buffer(
@@ -148,10 +160,6 @@ def update(self, hiddens: Tensor) -> None:
assert k > 1, "Must provide at least two hidden states"
assert hiddens.ndim == 4, "Must be of shape [batch, variants, choices, dim]"

# We don't actually call super because we need access to the earlier estimate
# of the population mean in order to update (cross-)covariances properly
# super().update(hiddens)

self.n += n

# *** Invariance (intra-cluster) ***
@@ -164,21 +172,28 @@ def update(self, hiddens: Tensor) -> None:
centroids = hiddens.mean(1)
deltas, deltas2 = [], []

# Iterating over classes
for i, h in enumerate(centroids.unbind(1)):
# Update the running means; super().update() does this usually
delta = h - self.class_means[i]
self.class_means[i] += delta.sum(dim=0) / self.n
# Update the running means if needed
if self.class_means is not None:
delta = h - self.class_means[i]
self.class_means[i] += delta.sum(dim=0) / self.n

# Post-mean update deltas are used to update the (co)variance
delta2 = h - self.class_means[i] # [n, d]
else:
delta = h - h.mean(dim=0)
delta2 = delta

# *** Variance (inter-cluster) ***
# See code at https://bit.ly/3YC9BhH and "Welford's online algorithm"
# in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
# Post-mean update deltas are used to update the (co)variance
delta2 = h - self.class_means[i] # [n, d]
self.intercluster_cov_M2.addmm_(delta.mT, delta2, alpha=1 / k)
deltas.append(delta)
deltas2.append(delta2)

# *** Negative covariance (contrastive) ***
# Iterating over pairs of classes (k, k') where k != k'
for i, d in enumerate(deltas):
for j, d_ in enumerate(deltas2):
# Compare to the other classes only
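As a usage sketch of the num_classes=None mode documented above (assuming the constructor and fit_streaming() signatures shown in this PR; the tensor shapes and config defaults are illustrative, not taken from the diff):

import torch
from elk.training import EigenReporter, EigenReporterConfig

hidden_size = 16
reporter = EigenReporter(EigenReporterConfig(), hidden_size, num_classes=None)

# Each update() batch has shape [batch, variants, choices, dim]. The two
# "datasets" below have 3 and 2 choices respectively; with num_classes=None
# the class means are not tracked, and the covariance statistics are averaged
# across update() calls instead of being pooled with Welford's algorithm.
reporter.update(torch.randn(8, 4, 3, hidden_size))
reporter.update(torch.randn(8, 4, 2, hidden_size))

train_loss = reporter.fit_streaming()  # eigendecomposition of the accumulated statistics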
40 changes: 22 additions & 18 deletions elk/training/train.py
@@ -85,46 +85,48 @@ def train_reporter(
train_dict = self.prepare_data(device, layer, "train")
val_dict = self.prepare_data(device, layer, "val")

(train_h, train_labels, _), *rest = train_dict.values()
(n, v, k, d) = train_h.shape
(first_train_h, train_labels, _), *rest = train_dict.values()
d = first_train_h.shape[-1]
if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
raise ValueError("All datasets must have the same hidden state size")

if not all(other_h.shape[2] == k for other_h, _, _ in rest):
raise ValueError("All datasets must have the same number of classes")

# Can't figure out a way to make this line less ugly
next(iter(train_dict.values()))[0].shape[-1]
reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
if isinstance(self.cfg.net, CcsReporterConfig):
assert len(train_dict) == 1, "CCS only supports single-task training"

reporter = CcsReporter(self.cfg.net, d, device=device)
train_loss = reporter.fit(train_h, train_labels)
train_loss = reporter.fit(first_train_h, train_labels)

(val_h, val_gt, _) = next(iter(val_dict.values()))
x0, x1 = train_h.unbind(2)
x0, x1 = first_train_h.unbind(2)
val_x0, val_x1 = val_h.unbind(2)
pseudo_auroc = reporter.check_separability(
train_pair=(x0, x1), val_pair=(val_x0, val_x1)
)

elif isinstance(self.cfg.net, EigenReporterConfig):
# To enable training on multiple tasks with different numbers of variants,
# we update the statistics in a streaming fashion and then fit
reporter = EigenReporter(self.cfg.net, d, k, device=device)
# We set num_classes to None to enable training on datasets with different
# numbers of classes. Under the hood, this causes the covariance statistics
# to be simply averaged across all batches passed to update().
reporter = EigenReporter(self.cfg.net, d, num_classes=None, device=device)

hidden_list, label_list = [], []
for ds_name, (train_h, train_labels, _) in train_dict.items():
hidden_list.append(train_h)
label_list.append(train_labels)
(_, v, k, _) = train_h.shape

# Datasets can have different numbers of variants and different numbers
# of classes, so we need to flatten them here before concatenating
hidden_list.append(rearrange(train_h, "n v k d -> (n v k) d"))
label_list.append(
to_one_hot(repeat(train_labels, "n -> (n v)", v=v), k).flatten()
)
reporter.update(train_h)

pseudo_auroc = None
train_loss = reporter.fit_streaming()
reporter.platt_scale(
to_one_hot(
repeat(torch.cat(label_list), "n -> (n v)", v=v), k
).flatten(),
rearrange(torch.cat(hidden_list), "n v k d -> (n v k) d"),
torch.cat(label_list),
torch.cat(hidden_list),
)
else:
raise ValueError(f"Unknown reporter config type: {type(self.cfg.net)}")
@@ -148,6 +150,8 @@ def train_reporter(
val_result = reporter.score(val_gt, val_h)

if val_lm_preds is not None:
(_, v, k, _) = val_h.shape

val_gt_cpu = repeat(val_gt, "n -> (n v)", v=v).cpu()
val_lm_preds = rearrange(val_lm_preds, "n v ... -> (n v) ...")
val_lm_auroc = roc_auc_score(
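The flattening in train.py above is what lets hidden states from datasets with different numbers of variants v and classes k be concatenated for Platt scaling. A standalone sketch of that reshaping, where to_one_hot is a hypothetical stand-in for elk's helper and the shapes are made-up examples:

import torch
from einops import rearrange, repeat

def to_one_hot(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Hypothetical stand-in for elk's utility: [N] integer labels -> [N, n_classes] floats.
    return torch.nn.functional.one_hot(labels, n_classes).float()

# Dataset A: 5 examples, 3 variants, 4 classes; dataset B: 7 examples, 2 variants, 2 classes.
datasets = [
    (torch.randn(5, 3, 4, 16), torch.randint(0, 4, (5,))),
    (torch.randn(7, 2, 2, 16), torch.randint(0, 2, (7,))),
]

hidden_list, label_list = [], []
for h, labels in datasets:
    n, v, k, d = h.shape
    # Collapse (n, v, k) into one axis so tensors with different v and k
    # can be concatenated along dim 0.
    hidden_list.append(rearrange(h, "n v k d -> (n v k) d"))
    # One binary label per (example, variant, choice): 1 for the correct choice, else 0.
    label_list.append(to_one_hot(repeat(labels, "n -> (n v)", v=v), k).flatten())

flat_labels = torch.cat(label_list)    # [sum over datasets of n * v * k]
flat_hiddens = torch.cat(hidden_list)  # [sum over datasets of n * v * k, 16]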
72 changes: 51 additions & 21 deletions tests/test_eigen_reporter.py
@@ -1,39 +1,69 @@
import pytest
import torch

from elk.training import EigenReporter, EigenReporterConfig
from elk.utils import batch_cov, cov_mean_fused


def test_eigen_reporter():
@pytest.mark.parametrize("track_class_means", [True, False])
def test_eigen_reporter(track_class_means: bool):
cluster_size = 5
hidden_size = 10
num_clusters = 100
N = 100

x = torch.randn(num_clusters, cluster_size, 2, hidden_size, dtype=torch.float64)
x = torch.randn(N, cluster_size, 2, hidden_size, dtype=torch.float64)
x1, x2 = x.chunk(2, dim=0)
x_neg, x_pos = x.unbind(2)

reporter = EigenReporter(EigenReporterConfig(), hidden_size, dtype=torch.float64)
reporter = EigenReporter(
EigenReporterConfig(),
hidden_size,
dtype=torch.float64,
num_classes=2 if track_class_means else None,
)
reporter.update(x1)
reporter.update(x2)

# Check that the streaming mean is correct
x_neg, x_pos = x.unbind(2)
pos_mu, neg_mu = x_pos.mean(dim=(0, 1)), x_neg.mean(dim=(0, 1))
torch.testing.assert_close(reporter.class_means[0], neg_mu)
torch.testing.assert_close(reporter.class_means[1], pos_mu)
if track_class_means:
# Check that the streaming mean is correct
neg_mu, pos_mu = x_neg.mean(dim=(0, 1)), x_pos.mean(dim=(0, 1))

# Check that the streaming covariance is correct
pos_centroids, neg_centroids = x_pos.mean(dim=1), x_neg.mean(dim=1)
expected_var = 0.5 * (batch_cov(pos_centroids) + batch_cov(neg_centroids))
torch.testing.assert_close(reporter.intercluster_cov, expected_var)
assert reporter.class_means is not None
torch.testing.assert_close(reporter.class_means[0], neg_mu)
torch.testing.assert_close(reporter.class_means[1], pos_mu)

# Check that the streaming invariance (intra-cluster variance) is correct
expected_invariance = 0.5 * (cov_mean_fused(x_pos) + cov_mean_fused(x_neg))
torch.testing.assert_close(reporter.intracluster_cov, expected_invariance)
# Check that the streaming covariance is correct
neg_centroids, pos_centroids = x_neg.mean(dim=1), x_pos.mean(dim=1)
true_cov = 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids))
torch.testing.assert_close(reporter.intercluster_cov, true_cov)

# Check that the streaming negative covariance is correct
true_xcov = (neg_centroids - neg_mu).mT @ (pos_centroids - pos_mu) / N
true_xcov = 0.5 * (true_xcov + true_xcov.mT)
torch.testing.assert_close(reporter.contrastive_xcov, true_xcov)
else:
assert reporter.class_means is None

# Check that the streaming negative covariance is correct
cross_cov = (pos_centroids - pos_mu).mT @ (neg_centroids - neg_mu) / num_clusters
cross_cov = 0.5 * (cross_cov + cross_cov.mT)
torch.testing.assert_close(reporter.contrastive_xcov, cross_cov)
# Check that the covariance matrices are correct. When we don't track class
# means, we expect intercluster_cov and contrastive_xcov to simply be averaged
# over each batch passed to update().
true_xcov = 0.0
true_cov = 0.0
for x_i in (x1, x2):
x_neg_i, x_pos_i = x_i.unbind(2)
neg_centroids, pos_centroids = x_neg_i.mean(dim=1), x_pos_i.mean(dim=1)
true_cov += 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids))

neg_mu_i, pos_mu_i = x_neg_i.mean(dim=(0, 1)), x_pos_i.mean(dim=(0, 1))
xcov_asym = (neg_centroids - neg_mu_i).mT @ (pos_centroids - pos_mu_i)
true_xcov += 0.5 * (xcov_asym + xcov_asym.mT)

torch.testing.assert_close(reporter.intercluster_cov, true_cov / 2)
torch.testing.assert_close(reporter.contrastive_xcov, true_xcov / N)

# Check that the streaming invariance (intra-cluster variance) is correct.
# This is actually the same whether or not we track class means.
expected_invariance = 0.5 * (cov_mean_fused(x_neg) + cov_mean_fused(x_pos))
torch.testing.assert_close(reporter.intracluster_cov, expected_invariance)

assert reporter.n == num_clusters
assert reporter.n == N
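For intuition about the else branch above: when class means are not tracked, the reporter's intercluster_cov is just the mean over update() calls of the per-batch centroid covariances. A small illustrative sketch (centroid_cov below is a stand-in, not elk's batch_cov, and its 1/(N - 1) normalization is an assumption):

import torch

def centroid_cov(centroids: torch.Tensor) -> torch.Tensor:
    # Covariance of per-example centroids, [N, d] -> [d, d].
    delta = centroids - centroids.mean(dim=0)
    return delta.mT @ delta / (len(centroids) - 1)

# Two update() batches of shape [N, variants, 2 classes, d].
x1, x2 = torch.randn(2, 50, 5, 2, 16, dtype=torch.float64).unbind(0)

per_batch_cov = []
for x in (x1, x2):
    x_neg, x_pos = x.unbind(2)                           # split the two classes
    neg_c, pos_c = x_neg.mean(dim=1), x_pos.mean(dim=1)  # centroids over variants
    per_batch_cov.append(0.5 * (centroid_cov(neg_c) + centroid_cov(pos_c)))

# With num_classes=None, the inter-cluster covariance is the plain average of the
# per-batch covariances rather than a Welford-style pooled estimate over all examples.
avg_cov = torch.stack(per_batch_cov).mean(dim=0)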