ENH - Add Pinball datafit #134

Merged (31 commits, Dec 9, 2022)

Changes from 24 commits

Commits:
413ef54  remove sqrt n_samples (Badr-MOUFAD, Nov 30, 2022)
2ef5eb7  update unittest (Badr-MOUFAD, Nov 30, 2022)
5c0bedc  info comment statsmodels (Badr-MOUFAD, Dec 1, 2022)
ca6ece7  add prox subdiff to sqrt df (Badr-MOUFAD, Dec 1, 2022)
a6303e5  implement ``PDCD_WS`` (Badr-MOUFAD, Dec 1, 2022)
e8fcee3  r sqrt_n from CB (Badr-MOUFAD, Dec 1, 2022)
339e98f  Merge branch 'r-sqrt-n' of https://github.com/Badr-MOUFAD/skglm into … (Badr-MOUFAD, Dec 1, 2022)
19a0ea9  bug w and subdiff (Badr-MOUFAD, Dec 1, 2022)
e01451d  unittest sqrt (Badr-MOUFAD, Dec 1, 2022)
dd36b88  add docs (Badr-MOUFAD, Dec 1, 2022)
523419b  fix docs SqrtQuadratic (Badr-MOUFAD, Dec 1, 2022)
71de179  Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm … (Badr-MOUFAD, Dec 2, 2022)
63a547b  subdiff --> fixed_point (Badr-MOUFAD, Dec 4, 2022)
f78d17d  efficient prox conjugate && fix tests (Badr-MOUFAD, Dec 5, 2022)
d0ae3a4  remove go (Badr-MOUFAD, Dec 5, 2022)
ad36485  MM remarks (Badr-MOUFAD, Dec 5, 2022)
f60bd59  fix test && clean ups (Badr-MOUFAD, Dec 5, 2022)
5a5f1ba  MM round 2 remarks (Badr-MOUFAD, Dec 5, 2022)
4f27c56  CI Trigger (Badr-MOUFAD, Dec 5, 2022)
fe45faa  implement pinball (Badr-MOUFAD, Dec 6, 2022)
3ce886f  unittest (Badr-MOUFAD, Dec 6, 2022)
6928502  fix pinball value && ST step (Badr-MOUFAD, Dec 6, 2022)
1271288  more unittest (Badr-MOUFAD, Dec 6, 2022)
bd1984a  fix bug prox pinball (Badr-MOUFAD, Dec 6, 2022)
36100c7  Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm … (Badr-MOUFAD, Dec 8, 2022)
1a03c60  MM remarks (Badr-MOUFAD, Dec 8, 2022)
4b3ea45  Update skglm/experimental/quantile_regression.py (mathurinm, Dec 8, 2022)
9cf2216  pinball expression (Badr-MOUFAD, Dec 8, 2022)
626b71d  Merge branch 'pinball-df' of https://github.com/Badr-MOUFAD/skglm int… (Badr-MOUFAD, Dec 8, 2022)
8e93720  sqrt --> pinball (Badr-MOUFAD, Dec 8, 2022)
0a247f0  quantile --> quantile_level (Badr-MOUFAD, Dec 9, 2022)
242 changes: 242 additions & 0 deletions skglm/experimental/pdcd_ws.py
@@ -0,0 +1,242 @@
import warnings

import numpy as np
from numpy.linalg import norm
from scipy.sparse import issparse

from numba import njit
from skglm.utils.jit_compilation import compiled_clone
from sklearn.exceptions import ConvergenceWarning


class PDCD_WS:
r"""Primal-Dual Coordinate Descent solver with working sets.

It solves::

\min_w F(Xw) + G(w)

using a primal-dual method on the saddle point problem::

\min_w \max_z <Xw, z> + G(w) - F^*(z)

where :math:`F` is the datafit term (:math:`F^*` its Fenchel conjugate)
and :math:`G` is the penalty term.

The datafit is required to be convex and proximable. Also, the penalty
is required to be convex, separable, and proximable.

    The solver is an adaptation of the algorithm in [1] to working sets [2].
    The working sets are built using a fixed-point distance strategy,
    where each feature is assigned a score based on how much its coefficient
    varies when performing a primal update::

        \text{score}_j = |w_j - \text{prox}_{\tau_j G_j}(w_j - \tau_j <X_j, z>)|

    where :math:`\tau_j` is the primal step associated with the j-th feature.

Parameters
----------
    max_iter : int, optional
        The maximum number of iterations, or equivalently the maximum
        number of solved subproblems.

max_epochs : int, optional
Maximum number of primal CD epochs on each subproblem.

    dual_init : array, shape (n_samples,), default None
The initialization of dual variables.
If None, they are initialized as the 0 vector ``np.zeros(n_samples)``.

p0 : int, optional
First working set size.

tol : float, optional
The tolerance for the optimization.

verbose : bool or int, default False
Amount of verbosity. 0/False is silent.

References
----------
.. [1] Olivier Fercoq and Pascal Bianchi,
"A Coordinate-Descent Primal-Dual Algorithm with Large Step Size and Possibly
Nonseparable Functions", SIAM Journal on Optimization, 2020,
https://epubs.siam.org/doi/10.1137/18M1168480,
code: https://github.com/Badr-MOUFAD/Fercoq-Bianchi-solver

.. [2] Bertrand, Q. and Klopfenstein, Q. and Bannier, P.-A. and Gidel, G.
and Massias, M.
"Beyond L1: Faster and Better Sparse Models with skglm", 2022
https://arxiv.org/abs/2204.07826
"""

def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
p0=100, tol=1e-6, verbose=False):
self.max_iter = max_iter
self.max_epochs = max_epochs
self.dual_init = dual_init
self.p0 = p0
self.tol = tol
self.verbose = verbose

def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None):
if issparse(X):
            raise ValueError("Sparse matrices are not yet supported in PDCD_WS solver.")

datafit, penalty = PDCD_WS._validate_init(datafit_, penalty_)
n_samples, n_features = X.shape

# init steps
        # Although this choice of steps violates the conditions stated in [1],
        # it yields, in practice, a convergent algorithm
        # with faster convergence
dual_step = 1 / norm(X, ord=2)
primal_steps = 1 / norm(X, axis=0, ord=2)

# primal vars
w = np.zeros(n_features) if w_init is None else w_init
Xw = np.zeros(n_samples) if Xw_init is None else Xw_init

# dual vars
if self.dual_init is None:
z = np.zeros(n_samples)
z_bar = np.zeros(n_samples)
else:
z = self.dual_init.copy()
z_bar = self.dual_init.copy()

p_objs = []
stop_crit = 0.
all_features = np.arange(n_features)

for iteration in range(self.max_iter):

# check convergence
opts_primal = _scores_primal(
X, w, z, penalty, primal_steps, all_features)

opt_dual = _score_dual(
y, z, Xw, datafit, dual_step)

stop_crit = max(
max(opts_primal),
opt_dual
)

if self.verbose:
current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
print(
f"Iteration {iteration+1}: {current_p_obj:.10f}, "
f"stopping crit: {stop_crit:.2e}")

if stop_crit <= self.tol:
break

# build ws
gsupp_size = (w != 0).sum()
ws_size = max(min(self.p0, n_features),
min(n_features, 2 * gsupp_size))

# similar to np.argsort()[-ws_size:] but without full sort
ws = np.argpartition(opts_primal, -ws_size)[-ws_size:]

# solve sub problem
# inplace update of w, Xw, z, z_bar
PDCD_WS._solve_subproblem(
y, X, w, Xw, z, z_bar, datafit, penalty,
primal_steps, dual_step, ws, self.max_epochs, tol_in=0.3*stop_crit)

current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
p_objs.append(current_p_obj)
else:
warnings.warn(
f"PDCD_WS did not converge for tol={self.tol:.3e} "
f"and max_iter={self.max_iter}.\n"
"Considering increasing `max_iter` or `tol`.",
category=ConvergenceWarning
)

return w, np.asarray(p_objs), stop_crit

@staticmethod
@njit
def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty,
primal_steps, dual_step, ws, max_epochs, tol_in):
n_features = X.shape[1]

for epoch in range(max_epochs):

for j in ws:
# update primal
old_w_j = w[j]
pseudo_grad = X[:, j] @ (2 * z_bar - z)
w[j] = penalty.prox_1d(
old_w_j - primal_steps[j] * pseudo_grad,
primal_steps[j], j)

                # keep Xw in sync with X @ w
delta_w_j = w[j] - old_w_j
if delta_w_j:
Xw += delta_w_j * X[:, j]

# update dual
z_bar[:] = datafit.prox_conjugate(z + dual_step * Xw,
dual_step, y)
z += (z_bar - z) / n_features

# check convergence
if epoch % 10 == 0:
opts_primal_in = _scores_primal(
X, w, z, penalty, primal_steps, ws)

opt_dual_in = _score_dual(
y, z, Xw, datafit, dual_step)

stop_crit_in = max(
max(opts_primal_in),
opt_dual_in
)

if stop_crit_in <= tol_in:
break

@staticmethod
def _validate_init(datafit_, penalty_):
# validate datafit
missing_attrs = []
for attr in ('prox_conjugate', 'subdiff_distance'):
if not hasattr(datafit_, attr):
missing_attrs.append(f"`{attr}`")

if len(missing_attrs):
raise AttributeError(
"Datafit is not compatible with PDCD_WS solver.\n"
"Datafit must implement `prox_conjugate` and `subdiff_distance`.\n"
f"Missing {' and '.join(missing_attrs)}."
)

# jit compile classes
compiled_datafit = compiled_clone(datafit_)
compiled_penalty = compiled_clone(penalty_)

return compiled_datafit, compiled_penalty


@njit
def _scores_primal(X, w, z, penalty, primal_steps, ws):
scores_ws = np.zeros(len(ws))

for idx, j in enumerate(ws):
next_w_j = penalty.prox_1d(w[j] - primal_steps[j] * X[:, j] @ z,
primal_steps[j], j)
scores_ws[idx] = abs(w[j] - next_w_j)

return scores_ws


@njit
def _score_dual(y, z, Xw, datafit, dual_step):
next_z = datafit.prox_conjugate(z + dual_step * Xw,
dual_step, y)
return norm(z - next_z, ord=np.inf)
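
For context, a minimal usage sketch of the solver (the random data and hyperparameters here are made up for illustration; the call signature follows the test file further below)::

    import numpy as np
    from skglm.penalties import L1
    from skglm.experimental.pdcd_ws import PDCD_WS
    from skglm.experimental.quantile_regression import Pinball

    # toy data, made up for this sketch
    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 10))
    y = rng.standard_normal(50)

    # solve min_w pinball(y, Xw) + alpha * ||w||_1
    solver = PDCD_WS(max_iter=1000, tol=1e-6)
    w, p_objs, stop_crit = solver.solve(X, y, Pinball(0.5), L1(alpha=0.1))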
73 changes: 73 additions & 0 deletions skglm/experimental/quantile_regression.py
@@ -0,0 +1,73 @@
import numpy as np
from numba import float64
from skglm.datafits import BaseDatafit
from skglm.utils.prox_funcs import ST_vec


class Pinball(BaseDatafit):
r"""Pinball datafit.

The datafit reads::

quantile * max(y - Xw, 0) + (1 - quantile) * max(Xw - y, 0)
Collaborator comment:

nitpick: the real value does not involve np.max, it is np.maximum(...).sum().
Maybe rewrite as a sum and use _i to denote sample indices? Check how sklearn does it.

Collaborator (author) reply:

I looked up the scikit-learn source code, but they don't specify the expression.

I am in favor of using a sum and _i to indicate samples.
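
For concreteness, a sketch of the suggested per-sample form (my reading of the suggestion, not the final merged docstring, with :math:`q` denoting the quantile)::

    \sum_{i=1}^{n} \left[ q \max(y_i - (Xw)_i, 0) + (1 - q) \max((Xw)_i - y_i, 0) \right]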


    where ``quantile`` is in ``[0, 1]``.

Parameters
----------
quantile : float
Quantile must be in ``[0, 1]``. When ``quantile=0.5``,
the datafit becomes a Least Absolute Deviation (LAD) datafit.
"""

def __init__(self, quantile):
self.quantile = quantile

def value(self, y, w, Xw):
# implementation taken from
# github.com/benchopt/benchmark_quantile_regression/blob/main/objective.py
quantile = self.quantile

residual = y - Xw
sign = residual >= 0

loss = quantile * sign * residual - (1 - quantile) * (1 - sign) * residual
return np.sum(loss)

def prox(self, w, step, y):
"""Prox of ||y - . || with step ``step``."""
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
shift_cst = (self.quantile - 1/2) * step
return y - ST_vec(y - w - shift_cst, step / 2)

def prox_conjugate(self, z, step, y):
"""Prox of ||y - . ||^* with step ``step``."""
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
# using Moreau decomposition
inv_step = 1 / step
return z - step * self.prox(inv_step * z, inv_step, y)

    def subdiff_distance(self, Xw, z, y):
        """Distance of ``z`` to the subdifferential of the pinball datafit at ``Xw``."""
        # computation note: using max(x, 0) = (x + |x|) / 2, the pinball datafit
        # reads 1/2 * ||y - Xw||_1 + (quantile - 1/2) * sum(y - Xw), hence
        # \partial pinball(Xw)_i = -(sign(y_i - Xw_i) / 2 + (quantile - 1/2))
        # when y_i != Xw_i, and -(quantile - 1/2) + 1/2 * [-1, 1] otherwise
        y_minus_Xw = y - Xw
        shift_cst = self.quantile - 1/2

        max_distance = 0.
        for i in range(len(y)):

            if y_minus_Xw[i] == 0.:
                distance_i = max(0, abs(z[i] + shift_cst) - 1/2)
            else:
                distance_i = abs(z[i] + shift_cst + np.sign(y_minus_Xw[i]) / 2)

            max_distance = max(max_distance, distance_i)

        return max_distance

def get_spec(self):
spec = (
('quantile', float64),
)
return spec

def params_to_dict(self):
return dict(quantile=self.quantile)
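
As a quick sanity check of the prox above (a sketch, not part of the PR; it assumes scipy is installed, calls the datafit directly outside of any solver, and relies on a derivative-free minimizer that is only accurate to modest tolerance), one can compare it against a brute-force minimization of the prox objective::

    import numpy as np
    from scipy.optimize import minimize
    from skglm.experimental.quantile_regression import Pinball

    rng = np.random.default_rng(0)
    y, w = rng.standard_normal(5), rng.standard_normal(5)
    step, df = 0.7, Pinball(0.3)

    # prox_{step * pinball}(w) minimizes step * pinball(y, u) + 0.5 * ||u - w||^2
    def objective(u):
        return step * df.value(y, None, u) + 0.5 * np.sum((u - w) ** 2)

    res = minimize(objective, x0=w, method="Nelder-Mead",
                   options=dict(xatol=1e-10, fatol=1e-10, maxiter=10_000))
    np.testing.assert_allclose(df.prox(w, step, y), res.x, atol=1e-4)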
20 changes: 19 additions & 1 deletion skglm/experimental/sqrt_lasso.py
@@ -5,7 +5,7 @@
from sklearn.linear_model._base import LinearModel, RegressorMixin

from skglm.penalties import L1
from skglm.utils.prox_funcs import ST_vec, proj_L2ball
from skglm.utils.prox_funcs import ST_vec, proj_L2ball, BST
from skglm.utils.jit_compilation import compiled_clone
from skglm.datafits.base import BaseDatafit
from skglm.solvers.prox_newton import ProxNewton
@@ -54,6 +54,24 @@ def raw_hessian(self, y, Xw):
fill_value = 1 / norm(y - Xw)
return np.full(n_samples, fill_value)

def prox(self, w, step, y):
"""Prox of ||y - . || with step."""
return y - BST(y - w, step)

def prox_conjugate(self, z, step, y):
"""Prox of ||y - . ||^* with step `step`."""
return proj_L2ball(z - step * y)

def subdiff_distance(self, Xw, z, y):
"""Distance of ``z`` to subdiff of ||y - . || at ``Xw``."""
# computation note: \partial ||y - . ||(Xw) = - \partial || . ||(y - Xw)
y_minus_Xw = y - Xw

if np.any(y_minus_Xw):
return norm(z + y_minus_Xw / norm(y_minus_Xw))

return norm(z - proj_L2ball(z))


class SqrtLasso(LinearModel, RegressorMixin):
"""Square root Lasso estimator based on Prox Newton solver.
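
As an aside on why ``prox_conjugate`` reduces to a projection (a sketch under the conventions above, not text from the PR; it assumes ``SqrtQuadratic`` takes no constructor arguments): the Fenchel conjugate of :math:`u \mapsto \|y - u\|` is :math:`\langle y, z \rangle` plus the indicator of the unit L2 ball, whose prox is exactly ``proj_L2ball(z - step * y)``. Numerically, the Moreau decomposition ties it back to ``prox``::

    import numpy as np
    from skglm.experimental.sqrt_lasso import SqrtQuadratic

    rng = np.random.default_rng(0)
    y, z = rng.standard_normal(6), rng.standard_normal(6)
    step, df = 0.5, SqrtQuadratic()

    # Moreau decomposition: z = prox_{t F*}(z) + t * prox_{F / t}(z / t)
    recomposed = df.prox_conjugate(z, step, y) + step * df.prox(z / step, 1 / step, y)
    np.testing.assert_allclose(recomposed, z, atol=1e-12)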
38 changes: 38 additions & 0 deletions skglm/experimental/tests/test_quantile_regression.py
@@ -0,0 +1,38 @@
import pytest
import numpy as np
from numpy.linalg import norm

from skglm.penalties import L1
from skglm.experimental.pdcd_ws import PDCD_WS
from skglm.experimental.quantile_regression import Pinball

from skglm.utils.data import make_correlated_data
from sklearn.linear_model import QuantileRegressor


@pytest.mark.parametrize('quantile', [0.3, 0.5, 0.7])
def test_PDCD_WS(quantile):
n_samples, n_features = 50, 10
X, y, _ = make_correlated_data(n_samples, n_features, random_state=123)

    # optimality condition for w = 0:
    # for g in the subdiff of pinball at y, X.T @ g must lie in the subdiff
    # of alpha * ||.||_1 at 0
    # hint: use max(x, 0) = (x + |x|) / 2 to get the subdiff of pinball
    alpha_max = norm(X.T @ (np.sign(y)/2 + (quantile - 0.5)), ord=np.inf)
alpha = alpha_max / 5

w = PDCD_WS(
dual_init=np.sign(y)/2 + (quantile - 0.5)
).solve(X, y, Pinball(quantile), L1(alpha))[0]

clf = QuantileRegressor(
quantile=quantile,
alpha=alpha/n_samples,
fit_intercept=False
).fit(X, y)

np.testing.assert_allclose(w, clf.coef_, atol=1e-5)


if __name__ == '__main__':
pass
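
For reference, a sketch of where ``alpha_max`` comes from (my derivation, following the hint in the test): writing :math:`\max(x, 0) = (x + |x|)/2` turns the pinball loss into :math:`\frac{1}{2}\|y - Xw\|_1 + (q - \frac{1}{2})\sum_i (y_i - (Xw)_i)`, so when all :math:`y_i \neq 0` its subdifferential at :math:`w = 0` is a singleton and::

    \alpha_{\max} = \Big\| X^\top \big( \tfrac{1}{2}\,\mathrm{sign}(y) + (q - \tfrac{1}{2}) \big) \Big\|_\infty

Any :math:`\alpha \geq \alpha_{\max}` makes :math:`w = 0` optimal, which is why the test uses ``alpha_max / 5`` to get a nontrivial solution.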