Pinv #875
Conversation
@adhulipa I think this implementation is the wrong way to go about it. Using SVD to compute the pseudoinverse means we don't need a primitive, kernels, etc. It can just be an op that resides in the linalg namespace. Basically something like the following (off the top of my head, so ymmv):

    def pinv(x):
        U, S, V = mx.linalg.svd(x)
        return (V[:len(x)].T * 1/S) @ U.T
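For intuition, here is a minimal runnable sketch of the same idea in NumPy (NumPy stands in for mx since the MLX op didn't exist yet; note `np.linalg.svd` returns `Vt` rather than `V`, so the transposes differ slightly from the pseudocode above):

```python
import numpy as np

def pinv(x):
    # Reduced SVD: U (m, k), S (k,), Vt (k, n), with k = min(m, n)
    U, S, Vt = np.linalg.svd(x, full_matrices=False)
    # pinv(x) = V @ diag(1/S) @ U.T, written with broadcasting
    return (Vt.T * (1.0 / S)) @ U.T

A = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])
P = pinv(A)
```

The result satisfies the Moore-Penrose identity `A @ P @ A == A` and matches `np.linalg.pinv` for full-rank inputs.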
Ahh I see! I didn't think about that. Thanks for the review @angeloskath. I suppose we could modify this PR to merge a Python version as a first step and then investigate whether a custom kernel is necessary. Would you recommend such a direction?
The op should be in C++ with a Python binding (we try to keep the C++ and Python APIs reasonably consistent). I think the Python impl from @angeloskath is just intended as pseudocode for that.
Ah yes, that makes sense. I should add a Python API that matches the C++ API for pinv(). I haven't gotten around to it yet. Thank you for taking a look, folks!
Force-pushed from 1837c9a to 1b513c7 (compare)
I made a few updates. Still gotta figure out how to fix the C++ op, where svd(A) returns u, s, vt with u not matching A's dims (when rectangular). This makes the matmul incompatible. I have a path to green where I need to tweak u to match the expected end shape. (I'm positive the SVD approach is accurate because I validated it in the MLX Python API, and in a few other languages such as MATLAB to be certain.) We can also use the PyTorch impl as a reference https://github.com/pytorch/pytorch/blob/2ffab6e663b9c6951048b8c8ba82d2cc5ca5c2fc/aten/src/ATen/native/LinearAlgebra.cpp#L532 -- just need to get around to it in due time.
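The shape incompatibility being described can be reproduced with NumPy's svd, whose API MLX mirrors (the matrix here is illustrative):

```python
import numpy as np

m, n = 5, 4
A = np.ones((m, n))

# Full SVD (the default): U is (m, m), Vt is (n, n),
# but S has only min(m, n) entries
U, S, Vt = np.linalg.svd(A)

# So the naive reconstruction U @ np.diag(S) @ Vt is a
# (5, 5) @ (4, 4) matmul and fails with a shape mismatch
shapes_chain = True
try:
    U @ np.diag(S) @ Vt
except ValueError:
    shapes_chain = False
```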
mlx/backend/common/pseudoinverse.cpp (outdated)

    // v* 5x4
    auto inner = transpose(matmul(s_plus, u));
    auto result = matmul(v, inner);
    copy(result, pinv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
There's a bug here (or at its call site where the pinv array is allocated) where the size/shape of the pinv array is not correctly set up:

    terminated by signal SIGSEGV (Address boundary error)
@adhulipa are you planning to come back to this?
Hi @awni, yes I will update this one. I am running into an issue where I haven't figured out how to allocate the rectangular array for the output/result array before passing it off to the pinv function. Apologies for the delay; other priorities took precedence lately. I should be able to dedicate a few hours this weekend -- likely 4-8 hours on 5/11.
@awni do you think it's better to close this PR and reopen against a newer mainline commit? Happy to do so if it helps keep your PR todo list clean.
It's more up to you. If you plan to work on it in the near future then you can keep it open (or start a new one if you prefer). If not, I would close it.
I'll keep it open for now. I'll close and re-open if it falls too far behind -- for now these changes are additive, so that's not a risk. It just needs a bit of polish/bug-fixing.
Made some progress. Need to fix a few more things.
I suspect there's something I need to figure out with how I'm using svd. I'm seeing output which seems contradictory to the MLX svd doc. Of course, I acknowledge MLX mimics the NumPy API, and NumPy indeed produces a similar result. But it looks like NumPy has support for a full_matrices=False option. (Fwiw, without full_matrices=False, the error is the same as MLX.)
I think you can just do something like this:

    U, S, V = mx.linalg.svd(A)
    K = min(A.shape[0], A.shape[1])
    Atilde = (U[:, :K] * S) @ V[:K, :]

We could add the slicing as an option like NumPy if it's useful. Also, I would recommend you rebase before making further progress, to make it easier to resolve conflicts.
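The suggested slicing checks out in NumPy, which plays the same role as mx here (NumPy's svd returns Vt, acting as the V in the snippet above; the matrix is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 4))

U, S, V = np.linalg.svd(A)   # full SVD: U (5, 5), V here is Vt (4, 4)
K = min(A.shape[0], A.shape[1])

# Slicing U to its first K columns makes the reconstruction chain:
# (5, K) * (K,) @ (K, 4) -> (5, 4)
Atilde = (U[:, :K] * S) @ V[:K, :]
```

`Atilde` reproduces `A` up to floating-point round-off.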
Ah thanks Awni! Will use that.
Small update: got a local build that correctly computes pinv in most of the tests. Cleaning up some things and polishing the code. Turns out I was incorrectly relying on the array computation graph API instead of computing the actual matrix products (d'oh!). Now I have some code locally using a BLAS matmul (sgemm) to compute the final pinv product. Will update this PR soon.
Updated the PR. This PR is in good enough shape for a review from @awni and other MLX folks. Thanks! Perhaps there's one more thing to check (on my part) in the Python tests. Will look into it. But in the meantime, this PR is still good for a review.
Drats. I have another bug to fix. I updated the tests to catch it. Will look into it and fix. Essentially, long rectangular matrices have a matmul dim mismatch -- which means I have made an error in the m, n, k calculations and/or the slice selections of U/Vt.
Fixed the bug for rectangular matrices where M > N 🎉

This PR is ready for a review from Awni, Angelos and other MLX folks. Thanks!
Hi @adhulipa. I think there shouldn't be a primitive for this operation. It can really just be an op in the linalg namespace.
Hi @angeloskath, ohh I see. I think I may have misinterpreted something in the thread here then, particularly what @awni shared after your earlier comment. Is it accurate to say that you meant this should just be an op in the linalg namespace, without a primitive?
Actually, @angeloskath, do you mean to say that we don't need a primitive, but all the logic of calling svd can stay? It seems like the recommendation here is to keep the core logic intact but just not make this a primitive. Am I understanding that correctly?
I think I'm starting to understand the motivation behind the C++-op-sans-primitive recommendation from Angelos. Pardon the roundabout way I needed to understand this 😅 The following change in linalg.cpp does pass the tests. Just checking a few more things before I can publish a new commit.
@angeloskath @awni -- question: do you folks feel like this is in good shape for a review? Of course, no rush from my pov; just thought I'd check.
Hi @angeloskath @awni -- I'm curious whether you think this PR is good for review or merge? Or perhaps there's another way to build this functionality? I am thankful for all the feedback so far and the pointers on the road forward :)
Hi! Sorry for taking some time to review this; it is close, but it still needs some changes before merging.
mlx/linalg.cpp (outdated)

    std::ostringstream msg;
    msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
           "with "
        << a.ndim() << " dimensions.";
I would change the above to

    << "with " << a.ndim() << " dimensions.";
mlx/linalg.cpp (outdated)

    const auto m = a.shape(-2);
    const auto n = a.shape(-1);
    const auto k = std::min(m, n);
    const auto rank = a.ndim();
rank is not used anywhere.
mlx/linalg.cpp (outdated)

    array pinv(pinv_shape, a.dtype(), nullptr, {});
    const auto U_slice = slice(U, {0, 0}, {m, k});
    const auto Vt_slice = slice(Vt, {0, 0}, {k, n});
    return matmul(matmul(transpose(Vt_slice), diag(1.0 / S)), transpose(U_slice));
No need for two matmuls and a diag; this can be done with a division. Also, the streams should be used appropriately.

    std::vector<int> starts(a.ndim(), 0);
    std::vector<int> ends = a.shape();
    int i = a.ndim() - 2;
    int j = a.ndim() - 1;
    ends[i] = n;
    ends[j] = k;
    array U_slice = slice(U, starts, ends, s);
    ends[i] = k;
    ends[j] = m;
    array V_slice = slice(Vt, starts, ends, s);
    return swapaxes(matmul(divide(U_slice, expand_dims(S, -2, s), s), V_slice, s), -1, -2, s);
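The division trick replaces a matmul against a diagonal matrix: dividing the columns of U by S via broadcasting is the same as multiplying by diag(1/S). A quick NumPy sanity check (shapes here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
U = rng.standard_normal((5, 3))
S = rng.uniform(1.0, 2.0, size=3)

# Dividing each column of U by the matching entry of S via broadcasting...
via_divide = U / S[None, :]

# ...equals the explicit matmul against diag(1/S), without building the
# (k, k) diagonal matrix or paying for a second matmul
via_diag = U @ np.diag(1.0 / S)
```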
Updated ✅ in new commit ac7e0cb; diff here.

Btw, I think your message above mixed up the ends[] dims for U_slice vs V_slice, leading to their dims being {n, k} and {k, m} respectively. But IIUC they should instead be U_slice -> {m, k} and V_slice -> {k, n}. The exact code you shared above caused a test error, and after investigating I updated the code to ensure the dims of U_slice are {m, k} and V_slice are {k, n}. I verified the results by comparing to numpy/matlab, but please double check my changes. Thanks!
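The corrected slice dims can be checked with a hypothetical NumPy mirror of the C++ under discussion (`pinv_sketch` and its names are illustrative; the final `.T` plays the role of the `swapaxes(..., -1, -2)`):

```python
import numpy as np

def pinv_sketch(A):
    m, n = A.shape
    k = min(m, n)
    U, S, Vt = np.linalg.svd(A)   # full SVD: U (m, m), Vt (n, n)
    U_slice = U[:, :k]            # (m, k), as in the corrected slicing
    Vt_slice = Vt[:k, :]          # (k, n)
    # (U diag(1/S) Vt)^T == V diag(1/S) U^T, i.e. the pseudoinverse
    return ((U_slice / S) @ Vt_slice).T

A = np.random.default_rng(1).standard_normal((5, 3))
P = pinv_sketch(A)                # shape (3, 5)
```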
mlx/linalg.cpp (outdated)

    } else {
      pinv_shape = {k, m};
    }
    array pinv(pinv_shape, a.dtype(), nullptr, {});
There is no need to initialize a pinv array at all; this array is not used anywhere.
Thank you, Angelos! No need to apologize re: the delay :) -- I appreciate the review and feedback. I will address it and update the PR this week.
Looks like a test failed. I will investigate.
It looks like the failure is a matter of precision. For the square matrix case, 1e-5 is a passing tolerance, whereas the existing 1e-6 level causes a failure. Similarly, for a 2x3x3 matrix, the result is within 1e-3 but not within 1e-6.
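The precision behavior can be sketched by running the pinv identity in single precision with NumPy (float32 chosen to mirror the dtype being tested; the matrix is illustrative):

```python
import numpy as np

# A well-conditioned matrix in float32
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]], dtype=np.float32)

P = np.linalg.pinv(A)

# Round-off from a single-precision SVD can exceed 1e-6 on the
# identity A @ P @ A == A, which is why the looser tolerance is needed
err = float(np.max(np.abs(A @ P @ A - A)))
```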
@angeloskath hey there Angelos, curious if you or others on MLX could come back to this PR sometime? Thanks!
Proposed changes

Add a Moore-Penrose pseudoinverse function. Inspired by the recent PRs from @nicolov adding svd and inv, this PR adds the pinv primitive.

Tests

Ran some tests locally and included them in the PR. Re-tested, and everything looks good.

Checklist

Put an x in the boxes that apply.

- [x] I ran pre-commit run --all-files to format my code / installed pre-commit prior to committing changes