Skip to content

Commit

Permalink
Refactor & rename lanczos_eigsh for convergence, correctness, & speed (EleutherAI#164)
Browse files Browse the repository at this point in the history

* Use a different default for ncv; throw an error when not converged

* truncated_eigh now works

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
norabelrose and pre-commit-ci[bot] committed Apr 7, 2023
1 parent be4980c commit 51fab16
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 272 deletions.
10 changes: 9 additions & 1 deletion elk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
from .extraction import Extract, extract_hiddens
from .training import EigenReporter, EigenReporterConfig
from .truncated_eigh import truncated_eigh

__all__ = ["extract_hiddens", "Extract"]
__all__ = [
"EigenReporter",
"EigenReporterConfig",
"extract_hiddens",
"Extract",
"truncated_eigh",
]
230 changes: 0 additions & 230 deletions elk/eigsh.py

This file was deleted.

21 changes: 14 additions & 7 deletions elk/training/eigen_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from dataclasses import dataclass
from typing import Optional
from warnings import warn

import torch
from torch import Tensor, nn, optim

from ..eigsh import lanczos_eigsh
from ..math_util import cov_mean_fused
from ..truncated_eigh import ConvergenceError, truncated_eigh
from .reporter import Reporter, ReporterConfig


Expand Down Expand Up @@ -169,20 +170,26 @@ def update(self, x_pos: Tensor, x_neg: Tensor) -> None:
self.contrastive_xcov_M2.addmm_(neg_delta.mT, pos_delta2)
self.contrastive_xcov_M2.addmm_(pos_delta.mT, neg_delta2)

def fit_streaming(self, warm_start: bool = False) -> float:
def fit_streaming(self) -> float:
"""Fit the probe using the current streaming statistics."""
A = (
self.config.var_weight * self.intercluster_cov
- self.config.inv_weight * self.intracluster_cov
- self.config.neg_cov_weight * self.contrastive_xcov
)
v0 = self.weight.T.squeeze() if warm_start else None

# We use "LA" (largest algebraic) instead of "LM" (largest magnitude) to
# ensure that the eigenvalue is positive and not a large negative one
L, Q = lanczos_eigsh(A, k=self.config.num_heads, v0=v0, which="LA")
self.weight.data = Q.T
try:
L, Q = truncated_eigh(A, k=self.config.num_heads)
except ConvergenceError:
warn(
"Truncated eigendecomposition failed to converge. Falling back on "
"PyTorch's dense eigensolver."
)

L, Q = torch.linalg.eigh(A)
L, Q = L[-self.config.num_heads :], Q[:, -self.config.num_heads :]

self.weight.data = Q.T
return -float(L[-1])

def fit(
Expand Down
Loading

0 comments on commit 51fab16

Please sign in to comment.