Skip to content

Commit

Permalink
Fixed ruff formatting issue
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Jan 25, 2024
1 parent 2844e7b commit 2cff73e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions concept_erasure/shrinkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def optimal_linear_shrinkage(

trace_S = trace(S_n)

# Since sigma0 is I * tr(S_n) / p, its squared Frobenius norm is just tr(S_n) ** 2 / p.
sigma0_norm_sq = trace_S ** 2 / p
# Since sigma0 is I * tr(S_n) / p, its squared Frobenius norm is tr(S_n) ** 2 / p.
sigma0_norm_sq = trace_S**2 / p
S_norm_sq = S_n.norm(dim=(-2, -1), keepdim=True) ** 2

prod_trace = sigma0_norm_sq #torch.linalg.diagonal(S_n) # trace(S_n @ sigma0)
prod_trace = sigma0_norm_sq
top = trace_S * trace_S.conj() * sigma0_norm_sq / n
bottom = S_norm_sq * sigma0_norm_sq - prod_trace * prod_trace.conj()

Expand Down

0 comments on commit 2cff73e

Please sign in to comment.