Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH - Add Pinball datafit #134

Merged
merged 31 commits into from
Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
413ef54
remove sqrt n_samples
Badr-MOUFAD Nov 30, 2022
2ef5eb7
update unittest
Badr-MOUFAD Nov 30, 2022
5c0bedc
info comment statsmodels
Badr-MOUFAD Dec 1, 2022
ca6ece7
add prox subdiff to sqrt df
Badr-MOUFAD Dec 1, 2022
a6303e5
implement ``PDCD_WS``
Badr-MOUFAD Dec 1, 2022
e8fcee3
r sqrt_n from CB
Badr-MOUFAD Dec 1, 2022
339e98f
Merge branch 'r-sqrt-n' of https://github.com/Badr-MOUFAD/skglm into …
Badr-MOUFAD Dec 1, 2022
19a0ea9
bug w and subdiff
Badr-MOUFAD Dec 1, 2022
e01451d
unittest sqrt
Badr-MOUFAD Dec 1, 2022
dd36b88
add docs
Badr-MOUFAD Dec 1, 2022
523419b
fix docs SqrtQuadratic
Badr-MOUFAD Dec 1, 2022
71de179
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Dec 2, 2022
63a547b
subdiff --> fixed_point
Badr-MOUFAD Dec 4, 2022
f78d17d
efficient prox conjugate && fix tests
Badr-MOUFAD Dec 5, 2022
d0ae3a4
remove go
Badr-MOUFAD Dec 5, 2022
ad36485
MM remarks
Badr-MOUFAD Dec 5, 2022
f60bd59
fix test && clean ups
Badr-MOUFAD Dec 5, 2022
5a5f1ba
MM round 2 remarks
Badr-MOUFAD Dec 5, 2022
4f27c56
CI Trigger
Badr-MOUFAD Dec 5, 2022
fe45faa
implement pinball
Badr-MOUFAD Dec 6, 2022
3ce886f
unittest
Badr-MOUFAD Dec 6, 2022
6928502
fix pinball value && ST step
Badr-MOUFAD Dec 6, 2022
1271288
more unittest
Badr-MOUFAD Dec 6, 2022
bd1984a
fix bug prox pinball
Badr-MOUFAD Dec 6, 2022
36100c7
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Dec 8, 2022
1a03c60
MM remarks
Badr-MOUFAD Dec 8, 2022
4b3ea45
Update skglm/experimental/quantile_regression.py
mathurinm Dec 8, 2022
9cf2216
pinball expression
Badr-MOUFAD Dec 8, 2022
626b71d
Merge branch 'pinball-df' of https://github.com/Badr-MOUFAD/skglm int…
Badr-MOUFAD Dec 8, 2022
8e93720
sqrt --> pinball
Badr-MOUFAD Dec 8, 2022
0a247f0
quantile --> quantile_level
Badr-MOUFAD Dec 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
subdiff --> fixed_point
  • Loading branch information
Badr-MOUFAD committed Dec 4, 2022
commit 63a547b8065f255a4a61baee0d2bac4fb9f869a7
45 changes: 35 additions & 10 deletions skglm/experimental/pdcd_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class PDCD_WS:
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
"""Primal-Dual Coordinate Descent solver with working sets.

Solver inspired by [1] that uses working sets.
Solver inspired by [1] that uses working sets [2].

Parameters
----------
Expand Down Expand Up @@ -83,13 +83,17 @@ def solve(self, X, y, datafit_, penalty_):
z_bar = self.dual_init.copy()

p_objs = []
stop_crit = 0.
all_features = np.arange(n_features)

for iter in range(self.max_iter):

# check convergence
opts_primal = penalty.subdiff_distance(w, X.T @ z, all_features)
opt_dual = datafit.subdiff_distance(Xw, z, y)
opts_primal = _scores_primal(
X, w, z, penalty, primal_steps, all_features)

opt_dual = _score_dual(
y, z, Xw, datafit, dual_step)

stop_crit = max(
max(opts_primal),
Expand Down Expand Up @@ -126,7 +130,7 @@ def solve(self, X, y, datafit_, penalty_):
warnings.warn(
f"PDCD_WS did not converge for tol={self.tol:.3e} "
f"and max_iter={self.max_iter}.\n"
"Considering increasing `max_iter` or decreasing `tol`.",
"Considering increasing `max_iter` or `tol`.",
category=ConvergenceWarning
)

Expand All @@ -137,16 +141,15 @@ def solve(self, X, y, datafit_, penalty_):
def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty,
primal_steps, dual_step, ws, max_epochs, tol_in):
n_features = X.shape[1]
past_pseudo_grad = np.zeros(len(ws))

for epoch in range(max_epochs):

for idx, j in enumerate(ws):
for j in ws:
# update primal
old_w_j = w[j]
past_pseudo_grad[idx] = X[:, j] @ (2 * z_bar - z)
pseudo_grad = X[:, j] @ (2 * z_bar - z)
w[j] = penalty.prox_1d(
old_w_j - primal_steps[j] * past_pseudo_grad[idx],
old_w_j - primal_steps[j] * pseudo_grad,
primal_steps[j], j)

# keep Xw syncr with X @ w
Expand All @@ -161,8 +164,11 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty,

# check convergence
if epoch % 10 == 0:
opts_primal_in = penalty.subdiff_distance(w, past_pseudo_grad, ws)
opt_dual_in = datafit.subdiff_distance(Xw, z, y)
opts_primal_in = _scores_primal(
X, w, z, penalty, primal_steps, ws)

opt_dual_in = _score_dual(
y, z, Xw, datafit, dual_step)

stop_crit_in = max(
max(opts_primal_in),
Expand Down Expand Up @@ -192,3 +198,22 @@ def _validate_init(datafit_, penalty_):
compiled_penalty = compiled_clone(penalty_)

return compiled_datafit, compiled_penalty


@njit
def _scores_primal(X, w, z, penalty, primal_steps, ws):
scores_ws = np.zeros(len(ws))

for idx, j in enumerate(ws):
next_w_j = penalty.prox_1d(w[j] - primal_steps[j] * X[:, j] @ z,
primal_steps[j], j)
scores_ws[idx] = abs(w[j] - next_w_j)

return scores_ws


@njit
def _score_dual(y, z, Xw, datafit, dual_step):
next_z = datafit.prox_conjugate(z + dual_step * Xw,
dual_step, y)
return norm(z - next_z, ord=np.inf)
7 changes: 7 additions & 0 deletions skglm/experimental/tests/test_sqrt_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,11 @@ def test_PDCD_WS(with_dual_init):


if __name__ == '__main__':
n_samples, n_features = 50, 1000
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)

alpha_max = norm(X.T @ y, ord=np.inf) / norm(y)
alpha = alpha_max / 10

PDCD_WS(verbose=1, dual_init=y / norm(y)).solve(X, y, SqrtQuadratic(), L1(alpha))
pass