Skip to content

Commit

Permalink
Merge pull request #7 from EleutherAI/where-fix
Browse files Browse the repository at this point in the history
Use torch.where instead of Tensor.where for PyTorch <2.0
  • Loading branch information
norabelrose committed Sep 7, 2023
2 parents 0d4dbcb + 5866ce4 commit 9689ef5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions concept_erasure/leace.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def eraser(self) -> LeaceEraser:
# Assuming PSD; account for numerical error
L.clamp_min_(0.0)

W = V * L.rsqrt().where(mask, 0.0) @ V.mT
W_inv = V * L.sqrt().where(mask, 0.0) @ V.mT
W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mT
W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mT
else:
W, W_inv = eye, eye

Expand Down
2 changes: 1 addition & 1 deletion concept_erasure/optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def psd_sqrt_rsqrt(A: Tensor) -> tuple[Tensor, Tensor]:
# We actually compute the pseudo-inverse here for numerical stability.
# Use the same heuristic as `torch.linalg.pinv` to determine the tolerance.
thresh = L[..., None, -1] * A.shape[-1] * torch.finfo(A.dtype).eps
rsqrt = U * L.rsqrt().where(L > thresh, 0.0) @ U.mT
rsqrt = U * torch.where(L > thresh, L.rsqrt(), 0.0) @ U.mT

return sqrt, rsqrt

Expand Down

0 comments on commit 9689ef5

Please sign in to comment.