-
Notifications
You must be signed in to change notification settings - Fork 287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
moving solvers related code to solvers/* directory, improving fista and nnls_hals #550
base: main
Are you sure you want to change the base?
Conversation
Not sure why the pytorch test for constrained_parafac is returning an error, I did not change the contents of the code there... seems to be a type mismatch between floats and integers but I have no idea why it happens. |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #550 +/- ##
==========================================
- Coverage 87.30% 80.10% -7.20%
==========================================
Files 125 132 +7
Lines 7868 8425 +557
==========================================
- Hits 6869 6749 -120
- Misses 999 1676 +677 ☔ View full report in Codecov by Sentry. |
@aarmey @JeanKossaifi @MarieRoald @yngvem if you can find some time to review this PR at your own pace that would be fantastic :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Here are some minor suggestions.
@cohenjer the PyTorch test failure may have to do with a typing issue here, too. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @aarmey; apart from the pytorch issue which I have yet to solve, I accounted for your comments.
Sorry for the weird duplication of answers.
@cohenjer by the way, |
There is a remaining issue with Tensorflow solve in the backend, when the inputs are tf.Variable (which is what we use for the factors in the CP class). Indeed, tl.solve calls fac.ndim, which is undefined for tf.Variable. def solve(lhs, rhs):
squeeze = False
try:
if rhs.ndim == 1:
squeeze = [-1]
rhs = tf.reshape(rhs, (-1, 1))
res = tf.linalg.solve(lhs, rhs)
if squeeze:
res = tf.squeeze(res, squeeze)
return res
except AttributeError:
# Variable inputs
res = tf.linalg.solve(lhs, rhs)
return res |
Interesting. |
Thanks for the tip (using |
This is a copy of the part of #542 related to solvers.
What this PR solves:
What this PR does:
(I left the old proximal file with the modified nnls_hals and fista solvers to show the changes, we can remove it before merging.)
solvers/penalizations.py
with a utility function to process input ridge/sparsity regularization (transform 1d input to list of correct length, avoid no regularization on only some modes which makes the factorization ill-posed).Possible improvements:
Note: If this PR is merged, I will proceed with the nonnegative algorithms improvements in #542.