Skip to content

Commit

Permalink
Bugfix for ot_barycenter with complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Sep 19, 2023
1 parent bace9ac commit 644ec94
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion concept_erasure/optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def ot_barycenter(
new_loss = mu.trace() + trace_avg - 2 * inner.mul(weights).sum(dim=0).trace()

# Break if the loss is not decreasing
if loss - new_loss < tol:
if loss.real - new_loss.real < tol:
break
else:
loss = new_loss
Expand Down

0 comments on commit 644ec94

Please sign in to comment.