Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fit logistic regression models on the GPU #103

Merged
merged 8 commits into from
Mar 7, 2023
Prev Previous commit
Next Next commit
Check how linearly separable the pseudo-labels are
  • Loading branch information
norabelrose committed Mar 5, 2023
commit 935006f07015cb386bb3c85e9841919fa372e52f
12 changes: 8 additions & 4 deletions elk/training/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
binary_cross_entropy_with_logits as bce_with_logits,
cross_entropy,
)
from typing import Optional
import torch


class Classifier(torch.nn.Module):
"""Linear classifier trained with supervised learning."""

def __init__(self, input_dim: int, num_classes: int = 1):
def __init__(
self, input_dim: int, num_classes: int = 1, device: Optional[str] = None
):
super().__init__()

self.linear = torch.nn.Linear(input_dim, num_classes)
self.linear = torch.nn.Linear(input_dim, num_classes, device=device)
self.linear.bias.data.zero_()
self.linear.weight.data.zero_()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this made me slightly worried at first, but I guess it's always fine since its output is never used as input to another layer?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I was doing this because I think logistic regression models are usually initialized to zero? It shouldn't matter since this is always a convex problem, although maybe in weird cases it could matter due to the 1e-4 stopping condition


Expand All @@ -23,7 +26,7 @@ def fit(
x: torch.Tensor,
y: torch.Tensor,
*,
max_iter: int = 1000,
max_iter: int = 10_000,
) -> float:
"""Fit parameters to the given data with LBFGS."""

Expand All @@ -38,12 +41,13 @@ def fit(
num_classes = self.linear.out_features
loss_fn = bce_with_logits if num_classes == 1 else cross_entropy
loss = torch.inf
y = y.float()

def closure():
nonlocal loss
optimizer.zero_grad()

loss = loss_fn(self(x), y)
loss = loss_fn(self(x).squeeze(-1), y)
loss.backward()

return float(loss)
Expand Down
40 changes: 30 additions & 10 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from functools import partial
from pathlib import Path
from simple_parsing import Serializable
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from torch import Tensor
from tqdm.auto import tqdm
Expand All @@ -23,6 +22,7 @@
import random
import torch
import torch.multiprocessing as mp
import warnings


@dataclass
Expand Down Expand Up @@ -78,9 +78,30 @@ def train_reporter(
int16_to_float32(assert_type(Tensor, val[f"hidden_{layer}"])),
method=cfg.normalization,
)

x0, x1 = train_h.unbind(dim=-2)
val_x0, val_x1 = val_h.unbind(dim=-2)

# Check how linearly separable the pseudo-labels are. If they're very
# separable, the algorithm may not converge to a good solution.
pseudo_clf = Classifier(train_h.shape[-1], device=device)
pseudo_labels = torch.cat(
[
torch.zeros_like(train_labels),
torch.ones_like(train_labels),
]
)
pseudo_clf.fit(torch.cat([x0, x1]).squeeze(1), pseudo_labels)
with torch.no_grad():
pseudo_preds = pseudo_clf(torch.cat([val_x0, val_x1]).squeeze(1))
pseudo_auroc = roc_auc_score(pseudo_labels.cpu(), pseudo_preds.cpu())
if pseudo_auroc > 0.6:
warnings.warn(
f"The pseudo-labels at layer {layer} are linearly separable with "
f"an AUROC of {pseudo_auroc:.3f}. This may indicate that the "
f"algorithm will not converge to a good solution."
)

reporter = Reporter(x0.shape[-1], cfg.net, device=device)
if cfg.label_frac:
num_labels = round(cfg.label_frac * len(train_labels))
Expand All @@ -99,22 +120,21 @@ def train_reporter(

lr_dir.mkdir(parents=True, exist_ok=True)
reporter_dir.mkdir(parents=True, exist_ok=True)
stats = [layer, train_loss, *val_result]
stats = [layer, pseudo_auroc, train_loss, *val_result]

if not cfg.skip_baseline:
train_labels_aug = torch.cat([train_labels, 1 - train_labels])
val_labels_aug = torch.cat([val_labels, 1 - val_labels])
val_labels_aug = torch.cat([val_labels, 1 - val_labels]).cpu()

# TODO: Once we implement cross-validation for CCS, we should benchmark
# against LogisticRegressionCV here.
X = torch.cat([x0, x1]).squeeze()
d = X.shape[-1]
# lr_model = LogisticRegression(max_iter=10_000)
lr_model = Classifier(d)
lr_model = Classifier(d, device=device)
lr_model.fit(X.view(-1, d), train_labels_aug)

X_val = torch.cat([val_x0, val_x1]).squeeze()
lr_preds = lr_model(X_val).sigmoid().cpu().numpy()
with torch.no_grad():
lr_preds = lr_model(X_val).sigmoid().cpu()

lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5)
lr_auroc = roc_auc_score(val_labels_aug, lr_preds)

Expand Down Expand Up @@ -146,7 +166,7 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):
devices = select_usable_devices(cfg.max_gpus)
num_devices = len(devices)

cols = ["layer", "train_loss", "loss", "acc", "cal_acc", "auroc"]
cols = ["layer", "pseudo_auroc", "train_loss", "loss", "acc", "cal_acc", "auroc"]
if not cfg.skip_baseline:
cols += ["lr_auroc", "lr_acc"]

Expand All @@ -163,6 +183,6 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None):
writer = csv.writer(f)
writer.writerow(cols)

mapper = pool.imap_unordered if num_devices > 1 else map
mapper = pool.imap if num_devices > 1 else map
for i, *stats in tqdm(mapper(fn, layers), total=len(layers)):
writer.writerow([i] + [f"{s:.4f}" for s in stats])