Make Classifier great again (with better hparams)
norabelrose committed Apr 19, 2023
1 parent 9909466 commit 3b4592c
Showing 1 changed file with 11 additions and 19 deletions.
elk/training/classifier.py: 30 changes (11 additions & 19 deletions)
@@ -47,17 +47,16 @@ def __init__(
         self.linear.weight.data.zero_()
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.linear(x)
+        return self.linear(x).squeeze(-1)
 
     @torch.enable_grad()
     def fit(
         self,
         x: Tensor,
         y: Tensor,
         *,
-        l2_penalty: float = 0.1,
+        l2_penalty: float = 0.0,
         max_iter: int = 10_000,
-        tol: float = 1e-4,
     ) -> float:
         """Fits the model to the input data using L-BFGS with L2 regularization.
 
@@ -68,17 +67,14 @@ def fit(
                 multiclass classification, where C is the number of classes.
             l2_penalty: L2 regularization strength.
             max_iter: Maximum number of iterations for the L-BFGS optimizer.
-            tol: Tolerance for the L-BFGS optimizer.
 
         Returns:
             Final value of the loss function after optimization.
         """
-
         optimizer = torch.optim.LBFGS(
             self.parameters(),
             line_search_fn="strong_wolfe",
             max_iter=max_iter,
-            tolerance_grad=tol,
         )
 
         num_classes = self.linear.out_features
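With `tolerance_grad=tol` gone, L-BFGS falls back to PyTorch's own stopping tolerances, which are tighter than the removed `tol=1e-4`. A sketch of the resulting construction, with a stand-in parameter list:

    import torch

    params = [torch.zeros(5, requires_grad=True)]  # stand-in for self.parameters()
    optimizer = torch.optim.LBFGS(
        params,
        line_search_fn="strong_wolfe",
        max_iter=10_000,
        # tolerance_grad and tolerance_change are left at PyTorch's defaults
        # (1e-7 and 1e-9 in recent releases), tighter than the removed tol=1e-4
    )
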
@@ -95,13 +91,12 @@ def closure():
             # Calculate the loss function
             logits = self(x).squeeze(-1)
             loss = loss_fn(logits, y)
-            if l2_penalty:
-                reg_loss = loss + l2_penalty * self.linear.weight.square().sum()
-            else:
-                reg_loss = loss
-
+            # Add L2 regularization penalty the way scikit-learn does
+            l2_reg = 0.5 * self.linear.weight.square().sum()
+
+            reg_loss = loss + l2_penalty * l2_reg
             reg_loss.backward()
 
             return float(reg_loss)
 
         optimizer.step(closure)
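The rewritten penalty mirrors scikit-learn's convention of regularizing with `0.5 * ||w||^2`, and it is now always applied (an `l2_penalty` of 0 simply contributes nothing). Assuming `bce_with_logits` uses the default mean reduction, the objective `mean_loss + l2_penalty * 0.5 * ||w||^2` should correspond to sklearn's `LogisticRegression` with `C = 1 / (l2_penalty * n)`. A rough cross-check sketch under those assumptions, ignoring the intercept:

    import torch
    from sklearn.linear_model import LogisticRegression

    torch.manual_seed(0)
    n, d = 2_000, 5
    x = torch.randn(n, d)
    y = (x @ torch.randn(d) > 0).float()

    l2_penalty = 0.1
    w = torch.zeros(d, requires_grad=True)
    opt = torch.optim.LBFGS([w], line_search_fn="strong_wolfe", max_iter=1_000)

    def closure():
        opt.zero_grad()
        # mean-reduced logistic loss plus the sklearn-style penalty from the diff
        loss = torch.nn.functional.binary_cross_entropy_with_logits(x @ w, y)
        reg_loss = loss + l2_penalty * 0.5 * w.square().sum()
        reg_loss.backward()
        return float(reg_loss)

    opt.step(closure)

    # Equivalent sklearn strength: C = 1 / (l2_penalty * n) for a mean-reduced loss
    sk = LogisticRegression(C=1 / (l2_penalty * n), fit_intercept=False, max_iter=1_000)
    sk.fit(x.numpy(), y.numpy())
    print(w.detach().numpy())   # should approximately match...
    print(sk.coef_.squeeze())   # ...sklearn's coefficients
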
@@ -117,7 +112,6 @@ def fit_cv(
         max_iter: int = 10_000,
         num_penalties: int = 10,
         seed: int = 42,
-        tol: float = 1e-4,
     ) -> RegularizationPath:
         """Fit using k-fold cross-validation to select the best L2 penalty.
 
@@ -130,7 +124,6 @@ def fit_cv(
             max_iter: Maximum number of iterations for the L-BFGS optimizer.
             num_penalties: Number of L2 regularization penalties to try.
             seed: Random seed for the k-fold cross-validation.
-            tol: Tolerance for the L-BFGS optimizer.
 
         Returns:
             `RegularizationPath` containing the penalties tried and the validation loss
@@ -148,11 +141,12 @@ def fit_cv(
         fold_size = num_samples // k
         indices = torch.randperm(num_samples, device=x.device, generator=rng)
 
-        l2_penalties = torch.logspace(-4, 4, num_penalties).tolist()
+        # Try a range of L2 penalties, including 0
+        l2_penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist()
 
         num_classes = self.linear.out_features
         loss_fn = bce_with_logits if num_classes == 1 else cross_entropy
-        losses = x.new_zeros((k, num_penalties))
+        losses = x.new_zeros((k, num_penalties + 1))
         y = y.to(
             torch.get_default_dtype() if num_classes == 1 else torch.long,
         )
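Prepending `0.0` to the log-spaced grid is what drives the companion change from `num_penalties` to `num_penalties + 1` columns in `losses`: the unregularized fit is now always one of the cross-validated candidates, consistent with the new `l2_penalty=0.0` default in `fit`. A quick sketch of the grid:

    import torch

    num_penalties = 10
    l2_penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist()
    print(len(l2_penalties))   # 11 == num_penalties + 1, matching losses' second dim
    print(l2_penalties[0])     # 0.0: the unregularized fit is always a candidate
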
@@ -167,9 +161,7 @@ def fit_cv(
 
             # Regularization path with warm-starting
             for j, l2_penalty in enumerate(l2_penalties):
-                self.fit(
-                    train_x, train_y, l2_penalty=l2_penalty, max_iter=max_iter, tol=tol
-                )
+                self.fit(train_x, train_y, l2_penalty=l2_penalty, max_iter=max_iter)
 
                 logits = self(val_x).squeeze(-1)
                 loss = loss_fn(logits, val_y)
@@ -180,7 +172,7 @@ def fit_cv(
 
         # Refit with the best penalty
         best_penalty = l2_penalties[best_idx]
-        self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter, tol=tol)
+        self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter)
         return RegularizationPath(l2_penalties, mean_losses.tolist())
 
     def nullspace_project(self, x: Tensor) -> Tensor:
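For reference, a hypothetical end-to-end use of the cross-validated fit. The `Classifier(input_dim)` constructor is an assumption (it is not shown in this diff); the `fit_cv` keyword arguments come from the signature above:

    import torch
    from elk.training.classifier import Classifier  # module path from the diff header

    x = torch.randn(500, 64)
    y = (torch.randn(500) > 0).float()

    clf = Classifier(64)  # hypothetical constructor signature
    path = clf.fit_cv(x, y, num_penalties=10, seed=42, max_iter=10_000)

    # Per the diff, fit_cv refits on all of (x, y) with the best penalty before
    # returning, so clf is immediately usable; `path` records the penalties tried
    # and the mean validation losses.
    logits = clf(x)
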
