Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smoke tests with tiny gpt2, fix CCSReporter #149

Merged
merged 10 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix lanczos_eigsh for small matrices
  • Loading branch information
norabelrose committed Mar 24, 2023
commit 89fbee5f27bd03d3959e470b52338edc83538f7e
20 changes: 19 additions & 1 deletion elk/eigsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ def lanczos_eigsh(
tol: Optional[float] = None,
seed: Optional[int] = None,
v0: Optional[Tensor] = None,
which: Literal["LA", "LM", "SA"] = "LM",
which: Literal["LA", "LM", "SA"] = "LA",
) -> tuple[Tensor, Tensor]:
"""Lanczos method for computing the top k eigenpairs of a symmetric matrix.

Implementation adapted from `cupyx.scipy.sparse.linalg.eigsh`, which in turn is
based on `scipy.sparse.linalg.eigsh`. Unlike the CuPy and SciPy functions, this
function supports batched inputs with arbitrary leading dimensions.

Unlike the above implementations, we use which='LA' as the default instead of
which='LM' because we are interested in algebraic eigenvalues, not magnitude.
Largest magnitude is also harder to implement in TorchScript.

Args:
A (Tensor): The matrix or batch of matrices of shape `[..., n, n]` for which to
compute eigenpairs. Must be symmetric, but need not be positive definite.
Expand All @@ -43,6 +47,20 @@ def lanczos_eigsh(
*leading, n, m = A.shape
assert n == m, "A must be a square matrix or a batch of square matrices."

# Short circuit if the matrix is too small; we can't outcompete the naive method.
if n <= 32:
L, Q = torch.linalg.eigh(A)
if which == "LA":
return L[..., -k:], Q[..., :, -k:]
elif which == "LM":
# Resort the eigenvalues and eigenvectors.
idx = L.abs().argsort(dim=-1, descending=True)
L = L.gather(-1, idx)
Q = Q.gather(-1, idx.unsqueeze(-1).expand(*idx.shape, n))
return L[..., :k], Q[..., :, :k]
elif which == "SA":
return L[..., :k], Q[..., :, :k]

if ncv is None:
ncv = min(max(2 * k, k + 32), n - 1)
else:
Expand Down
16 changes: 7 additions & 9 deletions tests/test_eigsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,29 @@
import torch


@pytest.mark.parametrize("n", [20, 40])
@pytest.mark.parametrize("which", ["LA", "SA"])
def test_lanczos_eigsh(which):
def test_lanczos_eigsh(n, which):
torch.manual_seed(42)

# Generate a random symmetric matrix
n = 10
A = torch.randn(n, n)
A = A + A.T

# Compute the top k eigenpairs using our implementation
k = 3
w, v = lanczos_eigsh(A, k=k, which=which)
w, v = lanczos_eigsh(A, which=which)

# Compute the top k eigenpairs using scipy
w_scipy, v_scipy = eigsh(A.numpy(), k=k, which=which)
w_scipy, v_scipy = eigsh(A.numpy(), which=which)

# Check that the eigenvalues match to within the tolerance
assert np.allclose(w, w_scipy, rtol=1e-3)
torch.testing.assert_allclose(w, torch.from_numpy(w_scipy), atol=1e-3, rtol=1e-3)

# Normalize the sign of the eigenvectors
for i in range(k):
for i in range(v.shape[-1]):
if v[torch.argmax(torch.abs(v[:, i])), i] < 0:
v[:, i] *= -1
if v_scipy[np.argmax(np.abs(v_scipy[:, i])), i] < 0:
v_scipy[:, i] *= -1

# Check that the eigenvectors match to within the tolerance
assert np.allclose(v.numpy(), v_scipy, rtol=1e-3)
torch.testing.assert_allclose(v, torch.from_numpy(v_scipy), atol=1e-3, rtol=1e-3)
3 changes: 0 additions & 3 deletions tests/test_smoke_elicit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from pathlib import Path

import pytest

from elk import ExtractionConfig
from elk.extraction import PromptConfig
from elk.training import CcsReporterConfig, EigenReporterConfig
Expand Down Expand Up @@ -34,7 +32,6 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path):
assert file in created_file_names


@pytest.mark.skip(reason="Fix me: EigenReporter crashes with tiny gpt2")
def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path):
"""
Currently this test fails with
Expand Down