
More stable and efficient Mahalanobis distance #40

Open
ejnnr opened this issue May 26, 2024 · 8 comments

@ejnnr
Owner

ejnnr commented May 26, 2024

We currently use pinv; we should probably use https://pytorch.org/docs/stable/generated/torch.linalg.lstsq.html instead (or maybe some third option). It doesn't seem like a big issue so far, though.

@VRehnberg
Collaborator

VRehnberg commented May 28, 2024

I ran a quick benchmark to compare speed

import timeit

import torch

n = 1024  # activation dimension; swept over a range of values in the benchmark
act = torch.randn(128, n)
cov = act.T @ act

# lstsqmm: solve cov @ z = act.T and contract with act (no explicit inverse)
t = timeit.timeit(lambda: torch.einsum("bi,ib->b", act, torch.linalg.lstsq(cov, act.T, rcond=1e-5).solution))
# pinv: computing the pseudoinverse itself
t = timeit.timeit(lambda: torch.linalg.pinv(cov, rcond=1e-5))
# pinvmm: applying a precomputed pseudoinverse (not counting the pinv above)
inv_cov = torch.linalg.pinv(cov, rcond=1e-5)
t = timeit.timeit(lambda: torch.einsum("bi,ij,bj->b", act, inv_cov, act))

Here n is the activation dimension and t is the time in seconds:
[Figure: perf_linalg — benchmark time t (seconds) vs n for lstsqmm, pinv, and pinvmm]

So in terms of speed, it looks like there could be an order-of-magnitude speed-up with least squares (though the trend looks a bit odd, so I'm not sure what's going on there). On the other hand, pinv only needs to be computed once during training, while lstsq has to run at inference time.

Memory usage might be another thing that improves with lstsq, but I didn't look at that. Then again, if I remember correctly, you'll run into 32-bit overflows in the indexing before memory usage becomes entirely unreasonable.

@VRehnberg
Collaborator

For SpectralSignatures we might want to consider lobpcg for the top singular vectors. It's a huge speed-up and memory reduction, but I'm not sure how stable it is.
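
A rough sketch of what that could look like (shapes and k are just illustrative, and I haven't checked stability):

import torch

acts = torch.randn(2048, 512)
acts = acts - acts.mean(dim=0)  # center the activations
gram = acts.T @ acts            # symmetric PSD, 512 x 512

# Top-k eigenpairs of acts.T @ acts, i.e. the top right singular vectors of
# the centered activation matrix, without a full SVD/eigendecomposition.
eigvals, eigvecs = torch.lobpcg(gram, k=2, largest=True)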

@ejnnr
Owner Author

ejnnr commented May 28, 2024

Thanks for this! Given that this doesn't seem to be a bottleneck for now, we probably shouldn't spend too much time on it, but I'm getting slightly nerd-sniped. If we assume the covariance matrix is full rank, then I think torch.cholesky_solve is probably much better than lstsq (i.e. we'd factorize the covariance matrix once using torch.linalg.cholesky, and then solving against that factorization is very cheap). It seems much faster based on some quick experiments: both the Cholesky factorization and each subsequent cholesky_solve take time comparable to the matrix multiply that pinv needs to do on each batch.
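
A minimal sketch of that idea (assuming the covariance matrix is full rank; shapes and names are just illustrative):

import torch

n = 512                      # activation dimension
acts = torch.randn(2048, n)  # training activations
cov = acts.T @ acts / acts.shape[0]

# Factorize once after training: cov = L @ L.T with L lower triangular.
L = torch.linalg.cholesky(cov)

def mahalanobis_sq(batch):
    # Solve cov @ z = x for each row x of the batch, i.e. compute x^T cov^{-1} x
    # without ever forming an explicit (pseudo)inverse.
    z = torch.cholesky_solve(batch.T, L)       # shape (n, batch)
    return torch.einsum("bi,ib->b", batch, z)  # squared distances, shape (batch,)

scores = mahalanobis_sq(torch.randn(128, n))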

I think allowing non-full-rank covariance matrices (which typically arise from having fewer samples than activation dimensions) is kind of nice, but I'm not sure whether using a pseudoinverse/lstsq is a reasonable way to handle that case anyway. It means that any variation at test time along axes with zero singular value is simply ignored in the distance, which seems a bit extreme. Not sure whether there's anything better, though.

If we want to keep the current behavior, we could keep track of how many samples were used to compute the covariance matrix, and then use Cholesky iff n_samples > d_hidden.

@ejnnr
Owner Author

ejnnr commented May 28, 2024

Oh, this is of course even better than cholesky_solve. We can use solve_triangular to do the forward substitution.
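
Something like this (same full-rank assumption as above, with L the factor from torch.linalg.cholesky):

def mahalanobis_sq(batch, L):
    # cov = L @ L.T, so x^T cov^{-1} x = ||L^{-1} x||^2: a single triangular
    # solve (forward substitution) instead of the two solves in cholesky_solve.
    y = torch.linalg.solve_triangular(L, batch.T, upper=False)  # shape (n, batch)
    return (y ** 2).sum(dim=0)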

@VRehnberg
Collaborator

Oh, that question title is spot on. Shows how much research I've done... ^^'

@ejnnr
Owner Author

ejnnr commented May 28, 2024

Note that lstsq on CUDA apparently only supports using QR factorization, which requires the matrix to be full-rank. pinv uses eigh behind the scenes (if we pass hermitian=True), which is why it's so slow, but which also means it works for arbitrary matrices.

@ejnnr
Owner Author

ejnnr commented May 28, 2024

Possible compromise between simplicity and speed/stability: use the Cholesky method from above by default, but fall back to pinv if the user passes something like singular=True. (That way we don't have to manually track how many samples were used or estimate the rank, but users can still use Mahalanobis with degenerate matrices if they really want to.) Like I said, though, this isn't super important for now.
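
Roughly, the interface could look like this (the singular flag and function name are hypothetical, just to illustrate the compromise):

import torch

def fit_mahalanobis(cov, singular=False):
    if singular:
        # Degenerate covariance: fall back to the pseudoinverse (eigh-based
        # with hermitian=True), which tolerates zero singular values.
        inv_cov = torch.linalg.pinv(cov, hermitian=True)
        return lambda x: torch.einsum("bi,ij,bj->b", x, inv_cov, x)
    # Full-rank covariance: factorize once, then one triangular solve per batch.
    L = torch.linalg.cholesky(cov)
    def score(x):
        y = torch.linalg.solve_triangular(L, x.T, upper=False)
        return (y ** 2).sum(dim=0)
    return score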

@VRehnberg
Collaborator

Sounds like a good compromise.
