Commit

Merge pull request #5 from EleutherAI/quadratic

Quadratic LEACE

norabelrose committed Aug 28, 2023
2 parents 1547093 + bfee3a4 commit 0d4dbcb
Showing 8 changed files with 593 additions and 8 deletions.
8 changes: 7 additions & 1 deletion concept_erasure/__init__.py
@@ -1,16 +1,22 @@
from .concept_scrubber import ConceptScrubber
from .groupby import GroupedTensor, groupby
from .leace import ErasureMethod, LeaceEraser, LeaceFitter
from .oracle import OracleEraser, OracleFitter
from .quadratic import QuadraticEraser, QuadraticFitter
from .shrinkage import optimal_linear_shrinkage
from .utils import assert_type

__all__ = [
    "assert_type",
    "groupby",
    "optimal_linear_shrinkage",
    "ConceptScrubber",
    "ErasureMethod",
    "GroupedTensor",
    "LeaceEraser",
    "LeaceFitter",
    "OracleEraser",
    "OracleFitter",
    "QuadraticEraser",
    "QuadraticFitter",
]
107 changes: 107 additions & 0 deletions concept_erasure/groupby.py
@@ -0,0 +1,107 @@
from dataclasses import dataclass
from typing import Callable, Iterator

import torch
from torch import LongTensor, Tensor


@dataclass(frozen=True)
class GroupedTensor:
    """A tensor split into groups along a given dimension.

    This class contains all the information needed to reconstruct the original tensor,
    or to take a list of tensors derived from the groups and coalesce them in such a
    way that the original order is restored.
    """

    dim: int
    """Dimension along which the tensor was split."""

    groups: list[Tensor]
    """List of tensors such that `groups[i]` contains all elements of `x` whose group
    label is `labels[i]`."""

    indices: LongTensor
    """Indices used to sort the original tensor."""

    labels: list[int]
    """Unique label for each element of `groups`."""

    def coalesce(self, groups: list[Tensor] | None = None) -> Tensor:
        """Fuse `groups or self.groups` back together, restoring the original order.

        This method is most useful when you want to group a tensor, perform an
        operation on each group, then combine the results back together.
        """
        if groups is None:
            groups = self.groups

        # First concatenate the groups back together
        fused = torch.cat(groups, dim=self.dim)

        # Invert the permutation to restore the original order
        return fused.index_select(self.dim, invert_indices(self.indices))

    def map(self, fn: Callable[[int, Tensor], Tensor]) -> "GroupedTensor":
        """Apply `fn` to each group & return a new `GroupedTensor` with the results."""
        results = [fn(label, group) for label, group in zip(self.labels, self.groups)]
        return GroupedTensor(self.dim, results, self.indices, self.labels)

    def __iter__(self) -> Iterator[tuple[int, Tensor]]:
        """Iterate over the groups and their labels."""
        for label, group in zip(self.labels, self.groups):
            yield label, group


def groupby(
    x: Tensor, key: Tensor, dim: int = 0, *, stable: bool = False
) -> GroupedTensor:
    """Efficiently split `x` into groups along `dim` according to `key`.

    This function is intended to mimic the behavior of `itertools.groupby`, but for
    PyTorch tensors. Under the hood, we sort `x` by `key` once, then return views
    onto the sorted tensor in order to minimize the number of memcpy and equality
    checking operations performed.

    By necessity this operation performs a host-device sync, since we need to know
    the number of groups and their sizes in order to create a view for each.

    Args:
        x: Tensor to split into groups.
        key: Tensor of group labels.
        dim: Dimension along which to split `x`.
        stable: If `True`, use a stable sorting algorithm. This is slower but ensures
            that the order of elements within each group is preserved.

    Returns:
        A `GroupedTensor` containing the groups, sorting indices, and labels.
    """
    assert key.dtype == torch.int64, "`key` must be int64"
    assert key.ndim == 1, "`key` must be 1D"

    key, indices = key.sort(stable=stable)
    labels, counts = key.unique_consecutive(return_counts=True)

    # Sort `x` by `key` along `dim`
    x = x.index_select(dim, indices)
    groups = x.split(counts.tolist(), dim=dim)

    return GroupedTensor(dim, groups, indices, labels.tolist())
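
Not part of the diff: a minimal usage sketch of how `groupby`, `map`, and `coalesce` compose; the per-group function (demeaning) is an arbitrary stand-in.

import torch
from concept_erasure import groupby

x = torch.randn(6, 2)
key = torch.tensor([1, 0, 1, 2, 0, 2])

# Split the rows of `x` by label, demean each group, then restore original order
grouped = groupby(x, key, dim=0)
centered = grouped.map(lambda label, g: g - g.mean(dim=0)).coalesce()
assert centered.shape == x.shape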


@torch.jit.script
def invert_indices(indices: Tensor) -> Tensor:
    """Efficiently invert the permutation represented by `indices`.

    Example:
        >>> indices = torch.tensor([2, 0, 1])
        >>> invert_indices(indices)
        tensor([1, 2, 0])
    """
    # Create an empty tensor to hold the reverse permutation
    reverse_indices = torch.empty_like(indices)

    # Scatter the indices to reverse the permutation; the arange must live on the
    # same device as `indices` for the scatter to work on GPU tensors
    reverse_indices.scatter_(0, indices, torch.arange(len(indices), device=indices.device))

    return reverse_indices
137 changes: 137 additions & 0 deletions concept_erasure/optimal_transport.py
@@ -0,0 +1,137 @@
import torch
from torch import Tensor

from .shrinkage import trace


def is_positive_definite(A: Tensor) -> Tensor:
    """Efficiently check if `A` is p.d. by attempting Cholesky decomposition."""
    return torch.linalg.cholesky_ex(A).info.eq(0)


@torch.jit.script
def psd_sqrt(A: Tensor) -> Tensor:
    """Compute the unique p.s.d. square root of a positive semidefinite matrix."""
    L, U = torch.linalg.eigh(A)
    L = L[..., None, :].clamp_min(0.0)
    return U * L.sqrt() @ U.mT


def psd_sqrt_rsqrt(A: Tensor) -> tuple[Tensor, Tensor]:
    """Efficiently compute both the p.s.d. sqrt & pinv sqrt of p.s.d. matrix `A`."""
    L, U = torch.linalg.eigh(A)
    L = L[..., None, :].clamp_min(0.0)

    # Square root is easy
    sqrt = U * L.sqrt() @ U.mT

    # We actually compute the pseudo-inverse here for numerical stability.
    # Use the same heuristic as `torch.linalg.pinv` to determine the tolerance.
    thresh = L[..., None, -1] * A.shape[-1] * torch.finfo(A.dtype).eps
    rsqrt = U * L.rsqrt().where(L > thresh, 0.0) @ U.mT

    return sqrt, rsqrt
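
A quick sanity check for the helpers above (a sketch, not part of the diff; the tolerances are rough float32 guesses): the square root should multiply back to `A`, and the pinv square root should invert it when `A` has full rank.

import torch
from concept_erasure.optimal_transport import is_positive_definite, psd_sqrt_rsqrt

torch.manual_seed(0)
R = torch.randn(4, 4)
A = R @ R.mT  # random p.s.d. (almost surely p.d.) matrix
assert is_positive_definite(A)

sqrt, rsqrt = psd_sqrt_rsqrt(A)
assert torch.allclose(sqrt @ sqrt, A, atol=1e-4)
assert torch.allclose(rsqrt @ sqrt, torch.eye(4), atol=1e-3)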


def ot_barycenter(
    Ks: Tensor, weights: Tensor | None = None, *, max_iter: int = 100
) -> Tensor:
    """Fixed-point iteration for the 2-Wasserstein barycenter of a set of Gaussians.

    Algorithm derived in "A fixed-point approach to barycenters in Wasserstein space"
    by Álvarez-Esteban et al. (2016) <https://arxiv.org/abs/1511.05355>.

    Args:
        Ks: `[n, d, d]` batch of covariance matrices, one for each centered Gaussian
            in the set whose barycenter we want to compute.
        weights: `[n]` batch of weights for each Gaussian.

    Returns:
        Covariance matrix of the barycenter.
    """
    n = len(Ks)
    assert n > 1, "Need at least two Gaussians to compute a barycenter"

    # Uniform weights by default
    if weights is None:
        weights = Ks.new_ones(n) / n
    else:
        assert len(weights) == n, "Need one weight per Gaussian"
        weights = weights / weights.sum()

    # Bookkeeping variables
    loss = torch.inf
    tol = torch.finfo(Ks.dtype).eps
    weights = weights.view(-1, 1, 1)  # Broadcastable to Ks

    # Initialize with arithmetic mean of covariance matrices
    mu = Ks.mul(weights).sum(dim=0)
    trace_avg = mu.trace()

    # Begin Álvarez-Esteban et al. fixed-point iteration
    for _ in range(max_iter):
        sqrt_mu, rsqrt_mu = psd_sqrt_rsqrt(mu)
        inner = psd_sqrt(sqrt_mu @ Ks @ sqrt_mu)

        # Equation 15 from Álvarez-Esteban et al. (2016)
        new_loss = mu.trace() + trace_avg - 2 * inner.mul(weights).sum(dim=0).trace()

        # Break if the loss is not decreasing
        if loss - new_loss < tol:
            break
        else:
            loss = new_loss

        # Equation 7 from Álvarez-Esteban et al. (2016)
        T = torch.sum(weights * rsqrt_mu @ inner @ rsqrt_mu, dim=0)
        mu = T @ mu @ T.mT

    return mu
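
A usage sketch for the iteration above (not part of the diff), assuming nothing beyond the signature: pass an `[n, d, d]` stack of covariance matrices and, optionally, per-Gaussian weights.

import torch
from concept_erasure.optimal_transport import ot_barycenter

torch.manual_seed(0)
Rs = torch.randn(3, 5, 5)
Ks = Rs @ Rs.mT  # batch of three random covariance matrices

mu_uniform = ot_barycenter(Ks)  # uniform weights by default
mu_weighted = ot_barycenter(Ks, weights=torch.tensor([1.0, 2.0, 1.0]))
assert mu_uniform.shape == mu_weighted.shape == (5, 5)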


def ot_distance(K1: Tensor, K2: Tensor) -> Tensor:
    """2-Wasserstein distance between N(0, K1) and N(0, K2)."""
    sqrt_K1 = psd_sqrt(K1)
    inner = psd_sqrt(sqrt_K1 @ K2 @ sqrt_K1)

    # Compute the 2-Wasserstein distance
    dist = torch.sqrt(trace(K1) + trace(K2) - 2 * trace(inner))
    return dist.squeeze(-1).squeeze(-1)
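
This is the closed form W2(N(0, K1), N(0, K2))^2 = tr(K1) + tr(K2) - 2 tr((K1^(1/2) K2 K1^(1/2))^(1/2)). A small check (a sketch, not part of the diff; tolerance is a rough float32 guess): the computed distance should be symmetric in its arguments.

import torch
from concept_erasure.optimal_transport import ot_distance

torch.manual_seed(0)
R1, R2 = torch.randn(2, 4, 4)
K1, K2 = R1 @ R1.mT, R2 @ R2.mT

# W2 is a metric, so the distance must not depend on argument order
assert torch.allclose(ot_distance(K1, K2), ot_distance(K2, K1), atol=1e-3)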


def ot_map(K1: Tensor, K2: Tensor) -> Tensor:
    """Optimal transport map from N(0, K1) to N(0, K2) in matrix form.

    Args:
        K1: Covariance matrix of the first Gaussian.
        K2: Covariance matrix of the second Gaussian.

    Returns:
        Unique p.s.d. matrix A such that N(0, A @ K1 @ A.T) = N(0, K2).
    """
    sqrt_K1, rsqrt_K1 = psd_sqrt_rsqrt(K1)
    return rsqrt_K1 @ psd_sqrt(sqrt_K1 @ K2 @ sqrt_K1) @ rsqrt_K1
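
The defining property in the docstring can be checked directly (a sketch, not part of the diff):

import torch
from concept_erasure.optimal_transport import ot_map

torch.manual_seed(0)
R1, R2 = torch.randn(2, 4, 4)
K1, K2 = R1 @ R1.mT, R2 @ R2.mT

# Pushing N(0, K1) forward through A should yield N(0, K2)
A = ot_map(K1, K2)
assert torch.allclose(A @ K1 @ A.mT, K2, atol=1e-3)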


def ot_midpoint(K1: Tensor, K2: Tensor, w1: float = 0.5, w2: float = 0.5) -> Tensor:
    """Covariance matrix of the 2-Wasserstein barycenter of N(0, K1) and N(0, K2).

    The barycenter of a set of distributions S is the unique distribution mu which
    minimizes the mean squared Wasserstein distance from each distribution in S to mu.
    Derived in "On Gaussian Wasserstein Barycenters" (Wessel Bruinsma & Gabriel Arpino)
    <https://gabrielarpino.github.io/files/wasserstein.pdf>.

    Args:
        K1: Covariance matrix of the first Gaussian.
        K2: Covariance matrix of the second Gaussian.
        w1: Weight of the first Gaussian.
        w2: Weight of the second Gaussian.

    Returns:
        Covariance matrix of the barycenter.
    """
    sqrt_K1, rsqrt_K1 = psd_sqrt_rsqrt(K1)
    product = sqrt_K1 @ psd_sqrt(sqrt_K1 @ K2 @ sqrt_K1) @ rsqrt_K1

    return w1 * w1 * K1 + w2 * w2 * K2 + w1 * w2 * (product + product.T)
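
Since `ot_midpoint` is the closed form of the two-Gaussian case, it should agree with the fixed-point iteration in `ot_barycenter`. A consistency sketch (not part of the diff; the tolerance is a rough guess for float32 convergence):

import torch
from concept_erasure.optimal_transport import ot_barycenter, ot_midpoint

torch.manual_seed(0)
R1, R2 = torch.randn(2, 4, 4)
K1, K2 = R1 @ R1.mT, R2 @ R2.mT

closed_form = ot_midpoint(K1, K2)
fixed_point = ot_barycenter(torch.stack([K1, K2]))
assert torch.allclose(closed_form, fixed_point, atol=1e-2)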