From 3b4592c75c262829bbb1a2bd48ee1898ec493f9d Mon Sep 17 00:00:00 2001
From: Nora Belrose
Date: Wed, 19 Apr 2023 05:57:11 +0000
Subject: [PATCH] Make Classifier great again (with better hparams)

---
 elk/training/classifier.py | 30 +++++++++++-------------------
 1 file changed, 11 insertions(+), 19 deletions(-)

diff --git a/elk/training/classifier.py b/elk/training/classifier.py
index a140af44..b92d0f7e 100644
--- a/elk/training/classifier.py
+++ b/elk/training/classifier.py
@@ -47,7 +47,7 @@ def __init__(
         self.linear.weight.data.zero_()
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.linear(x)
+        return self.linear(x).squeeze(-1)
 
     @torch.enable_grad()
     def fit(
@@ -55,9 +55,8 @@ def fit(
         x: Tensor,
         y: Tensor,
         *,
-        l2_penalty: float = 0.1,
+        l2_penalty: float = 0.0,
         max_iter: int = 10_000,
-        tol: float = 1e-4,
     ) -> float:
         """Fits the model to the input data using L-BFGS with L2 regularization.
 
@@ -68,17 +67,14 @@ def fit(
                 multiclass classification, where C is the number of classes.
             l2_penalty: L2 regularization strength.
             max_iter: Maximum number of iterations for the L-BFGS optimizer.
-            tol: Tolerance for the L-BFGS optimizer.
 
         Returns:
             Final value of the loss function after optimization.
         """
-
         optimizer = torch.optim.LBFGS(
             self.parameters(),
             line_search_fn="strong_wolfe",
             max_iter=max_iter,
-            tolerance_grad=tol,
         )
 
         num_classes = self.linear.out_features
@@ -95,13 +91,12 @@ def closure():
             # Calculate the loss function
             logits = self(x).squeeze(-1)
             loss = loss_fn(logits, y)
+            if l2_penalty:
+                reg_loss = loss + l2_penalty * self.linear.weight.square().sum()
+            else:
+                reg_loss = loss
 
-            # Add L2 regularization penalty the way scikit-learn does
-            l2_reg = 0.5 * self.linear.weight.square().sum()
-
-            reg_loss = loss + l2_penalty * l2_reg
             reg_loss.backward()
-
             return float(reg_loss)
 
         optimizer.step(closure)
@@ -117,7 +112,6 @@ def fit_cv(
         max_iter: int = 10_000,
         num_penalties: int = 10,
         seed: int = 42,
-        tol: float = 1e-4,
     ) -> RegularizationPath:
         """Fit using k-fold cross-validation to select the best L2 penalty.
 
@@ -130,7 +124,6 @@ def fit_cv(
             max_iter: Maximum number of iterations for the L-BFGS optimizer.
             num_penalties: Number of L2 regularization penalties to try.
             seed: Random seed for the k-fold cross-validation.
-            tol: Tolerance for the L-BFGS optimizer.
 
         Returns:
             `RegularizationPath` containing the penalties tried and the validation loss
@@ -148,11 +141,12 @@ def fit_cv(
         fold_size = num_samples // k
         indices = torch.randperm(num_samples, device=x.device, generator=rng)
 
-        l2_penalties = torch.logspace(-4, 4, num_penalties).tolist()
+        # Try a range of L2 penalties, including 0
+        l2_penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist()
 
         num_classes = self.linear.out_features
         loss_fn = bce_with_logits if num_classes == 1 else cross_entropy
-        losses = x.new_zeros((k, num_penalties))
+        losses = x.new_zeros((k, num_penalties + 1))
         y = y.to(
             torch.get_default_dtype() if num_classes == 1 else torch.long,
         )
@@ -167,9 +161,7 @@ def fit_cv(
 
             # Regularization path with warm-starting
            for j, l2_penalty in enumerate(l2_penalties):
-                self.fit(
-                    train_x, train_y, l2_penalty=l2_penalty, max_iter=max_iter, tol=tol
-                )
+                self.fit(train_x, train_y, l2_penalty=l2_penalty, max_iter=max_iter)
 
                 logits = self(val_x).squeeze(-1)
                 loss = loss_fn(logits, val_y)
@@ -180,7 +172,7 @@ def fit_cv(
 
         # Refit with the best penalty
         best_penalty = l2_penalties[best_idx]
-        self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter, tol=tol)
+        self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter)
        return RegularizationPath(l2_penalties, mean_losses.tolist())
 
     def nullspace_project(self, x: Tensor) -> Tensor:
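
Editor's note on the new `fit` logic: the patch trains with L-BFGS under a strong
Wolfe line search, and the L2 penalty on the weights is now applied without the
0.5 factor scikit-learn uses, and skipped entirely when `l2_penalty == 0` (the
new default). Below is a minimal standalone sketch of that behavior for the
binary case; `fit_logistic` and the toy data are illustrative stand-ins, not
part of elk's API.

    import torch
    from torch.nn.functional import binary_cross_entropy_with_logits as bce_with_logits

    # Illustrative stand-in for Classifier.fit (binary case): L-BFGS with a
    # strong Wolfe line search; the L2 penalty hits the weights only (not the
    # bias), with no 0.5 factor, and is skipped when l2_penalty == 0.
    def fit_logistic(x, y, l2_penalty: float = 0.0, max_iter: int = 10_000) -> float:
        w = torch.zeros(x.shape[1], requires_grad=True)  # zero init, like Classifier
        b = torch.zeros(1, requires_grad=True)
        opt = torch.optim.LBFGS([w, b], line_search_fn="strong_wolfe", max_iter=max_iter)

        def closure():
            opt.zero_grad()
            loss = bce_with_logits(x @ w + b, y)
            reg_loss = loss + l2_penalty * w.square().sum() if l2_penalty else loss
            reg_loss.backward()
            return float(reg_loss)

        opt.step(closure)
        return closure()  # final (regularized) loss, as fit() returns

    # Toy usage: two Gaussian blobs in 4 dimensions
    torch.manual_seed(0)
    x = torch.cat([torch.randn(100, 4) - 1, torch.randn(100, 4) + 1])
    y = torch.cat([torch.zeros(100), torch.ones(100)])
    print(fit_logistic(x, y, l2_penalty=1e-3))

One consequence of dropping the 0.5 factor: an `l2_penalty` value carried over
from the old code now regularizes twice as strongly.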
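The `fit_cv` changes are mostly bookkeeping around the widened penalty grid:
0.0 is prepended to the logspace sweep, so the grid holds num_penalties + 1
entries and the per-fold loss matrix needs num_penalties + 1 columns to match.
Here is a hedged sketch of that loop for the binary case, with warm-starting
(the weights carry over between fits rather than being reinitialized);
`cv_select_penalty` and its internals are hypothetical names, not elk's API.

    import torch
    from torch.nn.functional import binary_cross_entropy_with_logits as bce_with_logits

    # Hypothetical sketch of the patched fit_cv bookkeeping (binary case).
    def cv_select_penalty(x, y, k: int = 5, num_penalties: int = 10, seed: int = 42):
        n = len(y)
        fold_size = n // k
        rng = torch.Generator(device=x.device).manual_seed(seed)
        indices = torch.randperm(n, device=x.device, generator=rng)

        # 0.0 is prepended, so the grid and the loss matrix both gain a column
        l2_penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist()
        losses = x.new_zeros(k, num_penalties + 1)

        w = torch.zeros(x.shape[1], requires_grad=True)
        for i in range(k):
            val_idx = indices[i * fold_size : (i + 1) * fold_size]
            train_idx = torch.cat([indices[: i * fold_size], indices[(i + 1) * fold_size :]])
            train_x, train_y = x[train_idx], y[train_idx]
            val_x, val_y = x[val_idx], y[val_idx]

            # Regularization path with warm-starting: w carries over between fits
            for j, l2 in enumerate(l2_penalties):
                opt = torch.optim.LBFGS([w], line_search_fn="strong_wolfe", max_iter=1_000)

                def closure():
                    opt.zero_grad()
                    loss = bce_with_logits(train_x @ w, train_y)
                    reg_loss = loss + l2 * w.square().sum() if l2 else loss
                    reg_loss.backward()
                    return float(reg_loss)

                opt.step(closure)
                with torch.no_grad():
                    losses[i, j] = bce_with_logits(val_x @ w, val_y)

        mean_losses = losses.mean(dim=0)
        best_idx = int(mean_losses.argmin())
        return l2_penalties[best_idx], mean_losses.tolist()

The real method additionally refits on the full (x, y) with the winning penalty
before returning the `RegularizationPath`, which the sketch above omits.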