diff --git a/elk/eigsh.py b/elk/eigsh.py
index 93cc6675..10c1de60 100644
--- a/elk/eigsh.py
+++ b/elk/eigsh.py
@@ -13,7 +13,7 @@ def lanczos_eigsh(
     tol: Optional[float] = None,
     seed: Optional[int] = None,
     v0: Optional[Tensor] = None,
-    which: Literal["LA", "LM", "SA"] = "LM",
+    which: Literal["LA", "LM", "SA"] = "LA",
 ) -> tuple[Tensor, Tensor]:
     """Lanczos method for computing the top k eigenpairs of a symmetric matrix.
 
@@ -21,6 +21,10 @@ def lanczos_eigsh(
     based on `scipy.sparse.linalg.eigsh`. Unlike the CuPy and SciPy functions, this
     function supports batched inputs with arbitrary leading dimensions.
 
+    Unlike the above implementations, we use which='LA' as the default instead of
+    which='LM' because we are interested in algebraic eigenvalues, not magnitude.
+    Largest magnitude is also harder to implement in TorchScript.
+
     Args:
         A (Tensor): The matrix or batch of matrices of shape `[..., n, n]` for which
             to compute eigenpairs. Must be symmetric, but need not be positive definite.
@@ -43,6 +47,20 @@ def lanczos_eigsh(
     *leading, n, m = A.shape
     assert n == m, "A must be a square matrix or a batch of square matrices."
 
+    # Short circuit if the matrix is too small; we can't outcompete the naive method.
+    if n <= 32:
+        L, Q = torch.linalg.eigh(A)
+        if which == "LA":
+            return L[..., -k:], Q[..., :, -k:]
+        elif which == "LM":
+            # Re-sort eigenpairs by descending |eigenvalue|; eigenvectors are columns.
+            idx = L.abs().argsort(dim=-1, descending=True)
+            L = L.gather(-1, idx)
+            Q = Q.gather(-1, idx.unsqueeze(-2).expand_as(Q))
+            return L[..., :k], Q[..., :, :k]
+        elif which == "SA":
+            return L[..., :k], Q[..., :, :k]
+
     if ncv is None:
         ncv = min(max(2 * k, k + 32), n - 1)
     else:
diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index b68a9ed4..f9dc35dc 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -101,7 +101,9 @@ def extract_hiddens(
 
     # AutoModel should do the right thing here in nearly all cases. We don't actually
     # care what head the model has, since we are just extracting hidden states.
-    model = AutoModel.from_pretrained(cfg.model, torch_dtype="auto").to(device)
+    model = AutoModel.from_pretrained(
+        cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
+    ).to(device)
     # TODO: Maybe also make this configurable?
     # We want to make sure the answer is never truncated
     tokenizer = AutoTokenizer.from_pretrained(cfg.model, truncation_side="left")
diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py
index 54fc2292..c922ac4a 100644
--- a/elk/training/ccs_reporter.py
+++ b/elk/training/ccs_reporter.py
@@ -195,7 +195,11 @@ def loss(
             alpha = self.config.supervised_weight
             preds = p0.add(1 - p1).mul(0.5).squeeze(-1)
 
-            bce_loss = bce(preds, labels.type_as(preds))
+            # broadcast the labels, and flatten the predictions
+            # so that both are 1D tensors
+            broadcast_labels = labels.repeat_interleave(preds.shape[1]).float()
+            flattened_preds = preds.cpu().flatten()
+            bce_loss = bce(flattened_preds, broadcast_labels.type_as(flattened_preds))
             loss = alpha * bce_loss + (1 - alpha) * loss
 
         elif self.config.supervised_weight > 0:
diff --git a/tests/test_eigsh.py b/tests/test_eigsh.py
index 603c7e07..dc206d90 100644
--- a/tests/test_eigsh.py
+++ b/tests/test_eigsh.py
@@ -5,31 +5,29 @@ import torch
 
 
 
+@pytest.mark.parametrize("n", [20, 40])
 @pytest.mark.parametrize("which", ["LA", "SA"])
-def test_lanczos_eigsh(which):
+def test_lanczos_eigsh(n, which):
     torch.manual_seed(42)
 
-    # Generate a random symmetric matrix
-    n = 10
     A = torch.randn(n, n)
     A = A + A.T
 
     # Compute the top k eigenpairs using our implementation
-    k = 3
-    w, v = lanczos_eigsh(A, k=k, which=which)
+    w, v = lanczos_eigsh(A, which=which)
 
     # Compute the top k eigenpairs using scipy
-    w_scipy, v_scipy = eigsh(A.numpy(), k=k, which=which)
+    w_scipy, v_scipy = eigsh(A.numpy(), which=which)
 
     # Check that the eigenvalues match to within the tolerance
-    assert np.allclose(w, w_scipy, rtol=1e-3)
+    torch.testing.assert_allclose(w, torch.from_numpy(w_scipy), atol=1e-3, rtol=1e-3)
 
     # Normalize the sign of the eigenvectors
-    for i in range(k):
+    for i in range(v.shape[-1]):
         if v[torch.argmax(torch.abs(v[:, i])), i] < 0:
             v[:, i] *= -1
         if v_scipy[np.argmax(np.abs(v_scipy[:, i])), i] < 0:
             v_scipy[:, i] *= -1
 
     # Check that the eigenvectors match to within the tolerance
-    assert np.allclose(v.numpy(), v_scipy, rtol=1e-3)
+    torch.testing.assert_allclose(v, torch.from_numpy(v_scipy), atol=1e-3, rtol=1e-3)
diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py
new file mode 100644
index 00000000..3211271c
--- /dev/null
+++ b/tests/test_smoke_elicit.py
@@ -0,0 +1,63 @@
+from pathlib import Path
+
+from elk import ExtractionConfig
+from elk.extraction import PromptConfig
+from elk.training import CcsReporterConfig, EigenReporterConfig
+from elk.training.train import train, RunConfig
+
+"""
+TODO: These tests should work with deberta
+but you'll need to make deberta fp32 instead of fp16
+because pytorch cpu doesn't support fp16
+"""
+
+
+def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path):
+    model_path = "sshleifer/tiny-gpt2"
+    dataset_name = "imdb"
+    config = RunConfig(
+        data=ExtractionConfig(
+            model=model_path,
+            prompts=PromptConfig(dataset=dataset_name, max_examples=[10]),
+            # run on all layers, tiny-gpt only has 2 layers
+        ),
+        net=CcsReporterConfig(),
+    )
+    train(config, tmp_path)
+    # get the files in the tmp_path
+    files: list[Path] = list(tmp_path.iterdir())
+    created_file_names = {file.name for file in files}
+    expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"]
+    for file in expected_files:
+        assert file in created_file_names
+
+
+def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path):
+    """
+    Currently this test fails with
+    u -= torch.einsum("...ij,...i->...j", V[..., :k, :], proj)
+    V[..., k, :] = F.normalize(u, dim=-1)
+    ~~~~~~~~~ <--- HERE
+
+    u[:] = torch.einsum("...ij,...j->...i", A, V[..., k, :])
+
+    RuntimeError: select(): index 1 out of range for tensor of size [1, 2]
+    at dimension 0
+    """
+    model_path = "sshleifer/tiny-gpt2"
+    dataset_name = "imdb"
+    config = RunConfig(
+        data=ExtractionConfig(
+            model=model_path,
+            prompts=PromptConfig(dataset=dataset_name, max_examples=[10]),
+            # run on all layers, tiny-gpt only has 2 layers
+        ),
+        net=EigenReporterConfig(),
+    )
+    train(config, tmp_path)
+    # get the files in the tmp_path
+    files: list[Path] = list(tmp_path.iterdir())
+    created_file_names = {file.name for file in files}
+    expected_files = ["cfg.yaml", "metadata.yaml", "lr_models", "reporters", "eval.csv"]
+    for file in expected_files:
+        assert file in created_file_names