-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from EleutherAI/quadratic
Quadratic LEACE
- Loading branch information
Showing
8 changed files
with
593 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,22 @@ | ||
from .concept_scrubber import ConceptScrubber | ||
from .groupby import GroupedTensor, groupby | ||
from .leace import ErasureMethod, LeaceEraser, LeaceFitter | ||
from .oracle import OracleEraser, OracleFitter | ||
from .quadratic import QuadraticEraser, QuadraticFitter | ||
from .shrinkage import optimal_linear_shrinkage | ||
from .utils import assert_type | ||
|
||
# Public API of the package. Each name appears exactly once; the original list
# contained "ErasureMethod" twice.
__all__ = [
    "assert_type",
    "groupby",
    "optimal_linear_shrinkage",
    "ConceptScrubber",
    "ErasureMethod",
    "GroupedTensor",
    "LeaceEraser",
    "LeaceFitter",
    "OracleEraser",
    "OracleFitter",
    "QuadraticEraser",
    "QuadraticFitter",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from dataclasses import dataclass | ||
from typing import Callable, Iterator | ||
|
||
import torch | ||
from torch import LongTensor, Tensor | ||
|
||
|
||
@dataclass(frozen=True) | ||
class GroupedTensor: | ||
"""A tensor split into groups along a given dimension. | ||
This class contains all the information needed to reconstruct the original tensor, | ||
or to take a list of tensors derived from the groups and coalesce them in such a | ||
way that the original order is restored. | ||
""" | ||
|
||
dim: int | ||
"""Dimension along which the tensor was split.""" | ||
|
||
groups: list[Tensor] | ||
"""List of tensors such that `groups[i]` contains all elements of `x` whose group | ||
label is `labels[i]`.""" | ||
|
||
indices: LongTensor | ||
"""Indices used to sort the original tensor.""" | ||
|
||
labels: list[int] | ||
"""Unique label for each element of `groups`.""" | ||
|
||
def coalesce(self, groups: list[Tensor] | None = None) -> Tensor: | ||
"""Fuse `groups or self.groups` back together, restoring the original order. | ||
This method is most useful when you want to group a tensor, perform an operation | ||
on each group, then combine the results back together. | ||
""" | ||
if groups is None: | ||
groups = self.groups | ||
|
||
# First concatenate the groups back together | ||
fused = torch.cat(groups, dim=self.dim) | ||
|
||
# Invert the permutation to restore the original order | ||
return fused.index_select(self.dim, invert_indices(self.indices)) | ||
|
||
def map(self, fn: Callable[[int, Tensor], Tensor]) -> "GroupedTensor": | ||
"""Apply `fn` to each group & return a new `GroupedTensor` with the results.""" | ||
results = [fn(label, group) for label, group in zip(self.labels, self.groups)] | ||
return GroupedTensor(self.dim, results, self.indices, self.labels) | ||
|
||
def __iter__(self) -> Iterator[tuple[int, Tensor]]: | ||
"""Iterate over the groups and their labels.""" | ||
for label, group in zip(self.labels, self.groups): | ||
yield label, group | ||
|
||
|
||
def groupby(
    x: Tensor, key: Tensor, dim: int = 0, *, stable: bool = False
) -> GroupedTensor:
    """Efficiently split `x` into groups along `dim` according to `key`.

    This function is intended to mimic the behavior of `itertools.groupby`, but for
    PyTorch tensors. Under the hood, we sort `x` by `key` once, then split the
    sorted tensor into one chunk per distinct label.

    By necessity this operation performs a host-device sync since we need to know
    the number of groups and their sizes in order to split the tensor.

    Args:
        x: Tensor to split into groups.
        key: 1D int64 tensor of group labels, one per element of `x` along `dim`.
        dim: Dimension along which to split `x`.
        stable: If `True`, use a stable sorting algorithm. This is slower but ensures
            that the order of elements within each group is preserved.

    Returns:
        A `GroupedTensor` containing the groups, sorting indices, and labels.

    Raises:
        AssertionError: If `key` is not a 1D int64 tensor, or its length differs
            from the size of `x` along `dim`.
    """
    assert key.dtype == torch.int64, "`key` must be int64"
    assert key.ndim == 1, "`key` must be 1D"
    # A shorter key would make `index_select` silently drop elements of `x`
    assert (
        key.shape[0] == x.shape[dim]
    ), "`key` must have one label per element of `x` along `dim`"

    key, indices = key.sort(stable=stable)
    labels, counts = key.unique_consecutive(return_counts=True)

    # Sort `x` by `key` along `dim`
    x = x.index_select(dim, indices)

    # `split` returns a tuple; `GroupedTensor.groups` is declared as a list
    groups = list(x.split(counts.tolist(), dim=dim))

    return GroupedTensor(dim, groups, indices, labels.tolist())
|
||
|
||
@torch.jit.script
def invert_indices(indices: Tensor) -> Tensor:
    """Efficiently invert the permutation represented by `indices`.

    Args:
        indices: 1D integer tensor holding a permutation of `0..len(indices) - 1`.

    Returns:
        Tensor `inv` on the same device as `indices` with `inv[indices[i]] == i`.

    Example:
        >>> indices = torch.tensor([2, 0, 1])
        >>> invert_indices(indices)
        tensor([1, 2, 0])
    """
    # Create an empty tensor to hold the reverse permutation
    reverse_indices = torch.empty_like(indices)

    # Build the source positions on the same device as `indices`; a default-device
    # `arange` would make `scatter_` fail for indices living on another device.
    positions = torch.arange(len(indices), device=indices.device)

    # Scatter the positions to reverse the permutation
    reverse_indices.scatter_(0, indices, positions)

    return reverse_indices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
from .shrinkage import trace | ||
|
||
|
||
def is_positive_definite(A: Tensor) -> Tensor:
    """Test whether `A` is positive definite via an attempted Cholesky factorization.

    Returns:
        Boolean tensor that is `True` wherever the factorization reported success
        (an `info` code of zero).
    """
    _, info = torch.linalg.cholesky_ex(A)
    return info == 0
|
||
|
||
@torch.jit.script
def psd_sqrt(A: Tensor) -> Tensor:
    """Return the unique p.s.d. square root of the positive semidefinite matrix `A`.

    Eigenvalues are clamped at zero before the square root is taken, so small
    negative values do not produce NaNs.
    """
    eigvals, eigvecs = torch.linalg.eigh(A)

    # Shape [..., 1, d] so the values broadcast over the columns of `eigvecs`
    eigvals = eigvals[..., None, :].clamp_min(0.0)

    scaled = eigvecs * eigvals.sqrt()
    return scaled @ eigvecs.mT
|
||
|
||
def psd_sqrt_rsqrt(A: Tensor) -> tuple[Tensor, Tensor]:
    """Compute the p.s.d. square root of `A` and its pseudo-inverse in one pass.

    Both results reuse a single eigendecomposition of `A`.

    Returns:
        Tuple `(sqrt, rsqrt)`, where `rsqrt` is the pseudo-inverse of `sqrt`.
    """
    eigvals, eigvecs = torch.linalg.eigh(A)

    # Shape [..., 1, d] so the values broadcast over the columns of `eigvecs`
    eigvals = eigvals[..., None, :].clamp_min(0.0)
    eigvecs_t = eigvecs.mT

    # The square root itself is straightforward
    root = eigvecs * eigvals.sqrt() @ eigvecs_t

    # For numerical stability, invert only the eigenvalues above a tolerance,
    # giving a pseudo-inverse. The tolerance follows the `torch.linalg.pinv`
    # heuristic: largest eigenvalue (last, since `eigh` sorts ascending) times
    # the matrix size times machine epsilon.
    cutoff = eigvals[..., None, -1] * A.shape[-1] * torch.finfo(A.dtype).eps
    inv_roots = eigvals.rsqrt().where(eigvals > cutoff, 0.0)
    inv_root = eigvecs * inv_roots @ eigvecs_t

    return root, inv_root
|
||
|
||
def ot_barycenter(
    Ks: Tensor, weights: Tensor | None = None, *, max_iter: int = 100
) -> Tensor:
    """Fixed-point iteration for the 2-Wasserstein barycenter of a set of Gaussians.

    Algorithm derived in "A fixed-point approach to barycenters in Wasserstein space"
    by Álvarez-Esteban et al. (2016) <https://arxiv.org/abs/1511.05355>.

    Args:
        Ks: `[n, d, d]` batch of covariance matrices, one for each centered Gaussian
            in the set whose barycenter we want to compute.
        weights: `[n]` batch of weights for each Gaussian. They are normalized to
            sum to one; uniform weights are used when omitted.
        max_iter: Upper bound on the number of fixed-point iterations.

    Returns:
        Covariance matrix of the barycenter.
    """
    num_gaussians = len(Ks)
    assert num_gaussians > 1, "Need at least two Gaussians to compute a barycenter"

    if weights is not None:
        assert len(weights) == num_gaussians, "Need one weight per Gaussian"
        weights = weights / weights.sum()
    else:
        # Fall back to uniform weights
        weights = Ks.new_ones(num_gaussians) / num_gaussians

    # Bookkeeping for the stopping criterion
    prev_loss = torch.inf
    tol = torch.finfo(Ks.dtype).eps
    weights = weights.view(-1, 1, 1)  # Broadcastable to Ks

    # Start from the weighted arithmetic mean of the covariance matrices
    mu = Ks.mul(weights).sum(dim=0)
    trace_avg = mu.trace()

    # Álvarez-Esteban et al. fixed-point iteration
    for _ in range(max_iter):
        sqrt_mu, rsqrt_mu = psd_sqrt_rsqrt(mu)
        inner = psd_sqrt(sqrt_mu @ Ks @ sqrt_mu)

        # Equation 15 from Álvarez-Esteban et al. (2016)
        cross_term = inner.mul(weights).sum(dim=0).trace()
        cur_loss = mu.trace() + trace_avg - 2 * cross_term

        # Stop as soon as the loss no longer decreases by at least `tol`
        if prev_loss - cur_loss < tol:
            break
        prev_loss = cur_loss

        # Equation 7 from Álvarez-Esteban et al. (2016)
        transport = torch.sum((weights * rsqrt_mu) @ inner @ rsqrt_mu, dim=0)
        mu = transport @ mu @ transport.mT

    return mu
|
||
|
||
def ot_distance(K1: Tensor, K2: Tensor) -> Tensor:
    """2-Wasserstein distance between N(0, K1) and N(0, K2).

    Args:
        K1: Covariance matrix of the first Gaussian.
        K2: Covariance matrix of the second Gaussian.

    Returns:
        The distance, with the two trailing singleton dimensions squeezed out
        (presumably those kept by the `trace` helper — confirm against
        `.shrinkage.trace`).
    """
    sqrt_K1 = psd_sqrt(K1)
    inner = psd_sqrt(sqrt_K1 @ K2 @ sqrt_K1)

    # The squared distance is non-negative in exact arithmetic, but round-off can
    # push it slightly below zero (e.g. when K1 == K2), which would make `sqrt`
    # return NaN. Clamp at zero first.
    sq_dist = trace(K1) + trace(K2) - 2 * trace(inner)
    dist = torch.sqrt(sq_dist.clamp_min(0.0))
    return dist.squeeze(-1).squeeze(-1)
|
||
|
||
def ot_map(K1: Tensor, K2: Tensor) -> Tensor:
    """Matrix form of the optimal transport map from N(0, K1) to N(0, K2).

    Args:
        K1: Covariance matrix of the first Gaussian.
        K2: Covariance matrix of the second Gaussian.

    Returns:
        Unique p.s.d. matrix A such that N(0, A @ K1 @ A.T) = N(0, K2).
    """
    sqrt_K1, rsqrt_K1 = psd_sqrt_rsqrt(K1)

    # (K1^{1/2} K2 K1^{1/2})^{1/2}, sandwiched between K1^{-1/2} on both sides
    middle = psd_sqrt(sqrt_K1 @ K2 @ sqrt_K1)
    return rsqrt_K1 @ middle @ rsqrt_K1
|
||
|
||
def ot_midpoint(K1: Tensor, K2: Tensor, w1: float = 0.5, w2: float = 0.5) -> Tensor:
    """Covariance matrix of the 2-Wasserstein barycenter of N(0, K1) and N(0, K2).

    The barycenter of a set of distributions S is the unique distribution mu which
    minimizes the mean squared Wasserstein distance from each distribution in S to mu.

    Derived in "On Gaussian Wasserstein Barycenters" (Wessel Bruinsma & Gabriel Arpino)
    <https://gabrielarpino.github.io/files/wasserstein.pdf>.

    Args:
        K1: Covariance matrix of the first Gaussian.
        K2: Covariance matrix of the second Gaussian.
        w1: Weight of the first Gaussian.
        w2: Weight of the second Gaussian.

    Returns:
        Covariance matrix of the barycenter.
    """
    sqrt_K1, rsqrt_K1 = psd_sqrt_rsqrt(K1)
    product = sqrt_K1 @ psd_sqrt(sqrt_K1 @ K2 @ sqrt_K1) @ rsqrt_K1

    # Use `.mT` (transpose of the last two dims) like the rest of this module:
    # it equals `.T` for 2-D inputs, while `.T` is deprecated for tensors that
    # are not 2-D and would reverse every dimension of a batched input.
    return w1 * w1 * K1 + w2 * w2 * K2 + w1 * w2 * (product + product.mT)
Oops, something went wrong.