Skip to content

Commit

Permalink
Merge branch 'topo' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed May 14, 2024
2 parents 1f5477a + 9a9df31 commit 1921d47
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 16 deletions.
12 changes: 6 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import wandb
from w2s.ds_registry import load_and_process_dataset
from w2s.knn import gather_hiddens, zeta_filter
from w2s.knn import gather_hiddens, topofilter
from w2s.loss import log_confidence_loss
from w2s.roc_auc import roc_auc

Expand Down Expand Up @@ -254,13 +254,13 @@ def strong_processor(examples):
train_acts = gather_hiddens(strong_model, strong_train)
torch.save(train_acts, acts_path)

w2s_train = strong_train.remove_columns("labels").add_column(
"labels", train_probs.numpy()
)
w2s_train = strong_train.remove_columns("labels")
w2s_train = w2s_train.add_column("labels", train_probs.numpy())

if cfg.contamination > 0.0:
y = train_probs.to(train_acts.device)
top = zeta_filter(train_acts, y, k=cfg.outlier_k, q=1.0 - cfg.contamination)
w2s_train = w2s_train.select(top.tolist())
indices = topofilter(train_acts, y, cfg.contamination, k=cfg.outlier_k)
w2s_train = w2s_train.select(indices)

# Check gt metrics every 100 steps during w2s training.
# We can overfit to the weak labels before a single epoch.
Expand Down
63 changes: 53 additions & 10 deletions w2s/knn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import torch
from datasets import Dataset
from scipy.sparse.csgraph import connected_components
from tqdm.auto import tqdm
from transformers import (
PretrainedConfig,
Expand All @@ -9,6 +11,57 @@
from .utils import assert_type


def lcc_mask(adj: torch.Tensor):
"""Mask for membership in the largest connected component"""
num_cmps, cmps = connected_components(adj.cpu(), connection="strong")
cmp_sizes = np.bincount(cmps, minlength=num_cmps)
return torch.from_numpy(cmps == cmp_sizes.argmax()).to(adj.device)


def topo_cc(x: torch.Tensor, y: torch.Tensor, *, k: int = 5):
"""TopoCC label filtering algorithm."""
# All pairwise distances, leaving out the diagonal
dists = torch.cdist(x, x).fill_diagonal_(torch.inf)

# Find indices of `k` nearest neighbors
indices = dists.topk(k, largest=False).indices

# Create kNN adjacency matrix
adj = indices.new_zeros(len(x), len(x), dtype=torch.bool)
adj.scatter_(1, indices, True)

cls_mask = y[:, None] > 0.5
pos_mask = lcc_mask(adj & cls_mask)
neg_mask = lcc_mask(adj & ~cls_mask)
return neg_mask | pos_mask


def topofilter(
x: torch.Tensor, y: torch.Tensor, contamination: float = 0.1, *, k: int = 5
):
"""Remove points whose labels are far the average of their neighbors' labels."""

C = topo_cc(x, y, k=k)
x_C, y_C = x[C], y[C]

# Zeta filtering
dists = torch.cdist(x_C, x_C).fill_diagonal_(torch.inf)
indices = dists.topk(k, largest=False).indices

# Compute how far each point is from its average neighbor
knn_labels = y_C[indices].float().mean(1)
dists = torch.abs(y_C - knn_labels)

# Remove points that are furthest from their average neighbor
cc_removed = len(x) - len(x_C)
remaining = round(len(x) * contamination) - cc_removed
n = max(remaining, 0)

filtered = dists.topk(n).indices.cpu()
C_indices = C.nonzero().squeeze(1).cpu()
return np.delete(C_indices, filtered)


@torch.no_grad()
def gather_hiddens(model: PreTrainedModel, dataset: Dataset):
dataset = dataset.with_format("torch", device="cuda")
Expand All @@ -34,16 +87,6 @@ def cummean(x):
)


def knn_average(x: torch.Tensor, y: torch.Tensor, k: int):
"""Compute average of `y` of `k` nearest neighbors of `x`."""

# Find indices of `k` nearest neighbors
indices = torch.cdist(x, x).topk(k, largest=False).indices

# Compute average of `y` of `k` nearest neighbors
return y[indices].mean(1)


def zeta_filter(x: torch.Tensor, y: torch.Tensor, *, k: int = 0, q: float = 0.5):
"""Remove points whose labels are far the average of their neighbors' labels."""

Expand Down

0 comments on commit 1921d47

Please sign in to comment.