I am interested in domain generalization (DG, also known as "external validity") of statistical / machine learning models.
Anchor regression [1] is a recent idea interpolating between OLS and IV. [3] give ideas to generalize anchor regression to more general distributions (including classification). [2] is a "nice-to-read" summary, including ideas on how to extend to non-linear settings.
To my knowledge, no efficient implementations of anchor regression or classification exist. I'd be interested in contributing this to my favorite GLM library but would need some guidance.
What is Anchor Regression?
Anchor regression improves the DG / external validity of OLS by adding a regularization term penalizing the correlation between a so-called anchor variable and the regression's residuals. The anchor variable is assumed to be exogenous to the system, i.e., not directly causally affected by covariates, the outcome, or relevant hidden variables. See the following causal graph:
```mermaid
graph LR
A --> U & X & Y
U --> X & Y
X --> Y
```
What is an anchor? Say we want to predict health outcomes in the ICU. Possible anchor variables would be the hospital id (one-hot encoded) or some transformation of the time of year. The choice of anchor depends on the application. If we want to predict out of time but on the same hospitals as seen in training, using the time of year as anchor suffices; the hospital id should then be included in the covariates ($X$). If, however, we want to generalize across hospitals (i.e., predict on unseen hospitals), we need to use the hospital id as an anchor (and exclude it from the covariates). A similar example would be insurance with geographical location and time of year.
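For concreteness, here is a minimal sketch of how such an anchor matrix might be assembled; the data frame and column names are hypothetical and only for illustration:

```python
import pandas as pd

# Hypothetical ICU data; the column names are made up for illustration.
df = pd.DataFrame({
    "hospital_id": ["a", "b", "a", "c"],
    "month": [1, 4, 7, 10],
    "heart_rate": [72, 88, 95, 60],
})

# To generalize across hospitals, one-hot encode hospital_id as the anchor A
# and keep it out of the covariates X.
A = pd.get_dummies(df["hospital_id"], prefix="hospital").to_numpy(dtype=float)
X = df[["heart_rate", "month"]].to_numpy(dtype=float)
```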
Write $P_A$ for the $\ell_2$-projection onto the column space of $A$ (i.e., $P_A(\cdot) = \mathbb{E}[\cdot \mid A]$) and let $\gamma > 0$. In a regression setting, the anchor regression solution is given by

$$
b^\gamma = \underset{b}{\arg\min} \ \mathbb{E}_\textrm{train}[((\mathrm{Id} - P_A)(Y - X^T b))^2] + \gamma \, \mathbb{E}_\textrm{train}[(P_A(Y - X^T b))^2].
$$

Given samples from $P_\mathrm{train}$, write $\Pi_A$ for the projection onto the column space of $A$; this can be estimated as

$$
\hat b^\gamma = \underset{b}{\arg\min} \ \frac{1}{n} \|(\mathrm{Id} - \Pi_A)(y - X b)\|_2^2 + \frac{\gamma}{n} \|\Pi_A (y - X b)\|_2^2. \tag{1}
$$
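As a sanity check, the finite-sample objective can be written down and minimized directly with an off-the-shelf optimizer. This is a naive sketch on simulated data, not an efficient implementation:

```python
import numpy as np
from scipy.optimize import minimize


def anchor_objective(b, X, y, A, gamma):
    # Residuals and their projection onto the column space of A.
    r = y - X @ b
    r_proj = A @ np.linalg.lstsq(A, r, rcond=None)[0]  # Pi_A r
    return np.mean((r - r_proj) ** 2) + gamma * np.mean(r_proj ** 2)


rng = np.random.default_rng(0)
n, p, q = 200, 5, 2
A = rng.normal(size=(n, q))
X = rng.normal(size=(n, p)) + A @ rng.normal(size=(q, p))  # A affects X
y = X @ rng.normal(size=p) + A[:, 0] + rng.normal(size=n)  # A affects Y

b_hat = minimize(anchor_objective, np.zeros(p), args=(X, y, A, 5.0)).x
```

For $\gamma = 1$ the two terms recombine into the plain residual sum of squares, so this reduces to OLS.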
[1] show that the anchor regression solution protects against the worst-case risk with respect to distribution shifts induced through the anchor variable. Here $\gamma$ controls the size of the set of distributions the method protects against, which is generated by $\sqrt{\gamma}$-times the shifts as seen in the training data [1, Theorem 1].
In an instrumental variable (IV) setting (no direct causal effects $A \to U$ or $A \to Y$, and a "sufficient" effect $A \to X$), anchor regression interpolates between OLS and IV regression, with $\hat b^\gamma$ converging to the IV solution for $\gamma \to \infty$. This is because the IV solution can be written as

$$
\hat b^{\mathrm{IV}} = \underset{b}{\arg\min} \ \|\Pi_A (y - X b)\|_2^2.
$$

In low-dimensional settings, (1) can be optimized using the transformation

$$
\tilde X = (\mathrm{Id} + (\sqrt{\gamma} - 1) \Pi_A) X, \qquad \tilde y = (\mathrm{Id} + (\sqrt{\gamma} - 1) \Pi_A) y,
$$

followed by OLS of $\tilde y$ on $\tilde X$, where $\Pi_A = A (A^T A)^{-1} A^T$ (this need not be calculated explicitly, though).
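A minimal sketch of this transformation on simulated data (the helper name is mine):

```python
import numpy as np


def anchor_transform(A, M, gamma):
    # M + (sqrt(gamma) - 1) * Pi_A M, without forming Pi_A explicitly.
    M_proj = A @ np.linalg.lstsq(A, M, rcond=None)[0]
    return M + (np.sqrt(gamma) - 1) * M_proj


rng = np.random.default_rng(0)
n, p, q, gamma = 200, 5, 2, 5.0
A = rng.normal(size=(n, q))
X = rng.normal(size=(n, p)) + A @ rng.normal(size=(q, p))
y = X @ rng.normal(size=p) + A[:, 0] + rng.normal(size=n)

# Anchor regression = OLS on the transformed data.
X_tilde = anchor_transform(A, X, gamma)
y_tilde = anchor_transform(A, y, gamma)
b_anchor = np.linalg.lstsq(X_tilde, y_tilde, rcond=None)[0]
```

This works because $\mathrm{Id} + (\sqrt{\gamma} - 1)\Pi_A = (\mathrm{Id} - \Pi_A) + \sqrt{\gamma}\,\Pi_A$ and the two residual components are orthogonal, so the squared norm of the transformed residual decomposes into exactly the two terms of (1).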
What is Distributional Anchor Regression?

[2] present ideas on how to generalize anchor regression from OLS to GLMs. In particular, if $f$ are raw scores, they propose to use the residuals

$$
r := -\frac{\partial}{\partial f} \ell(f, y).
$$

For $f = X^T \beta$ and $\ell(f, y) = \frac{1}{2}(y - f)^2$ this reduces to anchor regression. For logistic regression, with $Y \in \{-1, 1\}$ and

$$
\ell(f, y) = \log(1 + \exp(-y f)),
$$

this yields the residuals

$$
r = \tilde y - p,
$$

where $\tilde y = \frac{y}{2} + 0.5 \in \{0, 1\}$ and $p_i = (1 + \exp(-f_i))^{-1}$.

Define $\ell^\gamma(f, y) := \ell(f, y) + (\gamma - 1) \|\Pi_A r\|_2^2$. The gradient of the anchor loss is given by

$$
\frac{d}{d f_i} \ell^\gamma(f, y) = -r_i - 2 (\gamma - 1)\, p_i (1 - p_i)\, (\Pi_A r)_i.
$$

The Hessian is (not pretty)

$$
\frac{d^2}{d f_i \, d f_j} \ell^\gamma(f, y) = \mathbb{1}_{\{i = j\}}\, p_i (1 - p_i) \left(1 - 2 (\gamma - 1) (1 - 2 p_i) (\Pi_A r)_i \right) + 2 (\gamma - 1)\, p_i (1 - p_i)\, p_j (1 - p_j)\, (\Pi_A)_{i, j}.
$$

If $f = X^T \beta$, then (here, $\cdot$ is matrix multiplication)

$$
\nabla_\beta \ell^\gamma(f, y) = X^T \cdot \nabla_f \ell^\gamma(f, y)
$$

and

$$
\nabla^2_\beta \ell^\gamma(f, y) = X^T \cdot \nabla^2_f \ell^\gamma(f, y) \cdot X.
$$

Computational considerations

Here is some numpy code calculating and testing the above derivatives:
```python
import numpy as np
import pytest
from scipy.optimize import approx_fprime


def predictions(f):
    """Sigmoid: p = P(Y = 1 | f)."""
    return 1 / (1 + np.exp(-f))


def proj(A, f):
    """Project f onto the column space of A."""
    return np.dot(A, np.linalg.lstsq(A, f, rcond=None)[0])


def proj_matrix(A):
    """Projection matrix Pi_A = A (A^T A)^{-1} A^T."""
    return np.dot(np.dot(A, np.linalg.inv(A.T @ A)), A.T)


def loss(X, beta, y, A, gamma):
    """Anchor logistic loss, with y in {-1, 1}."""
    f = X @ beta
    r = (y / 2 + 0.5) - predictions(f)  # r = y_tilde - p
    return np.sum(np.log1p(np.exp(-y * f))) + (gamma - 1) * np.sum(proj(A, r) ** 2)


def grad(X, beta, y, A, gamma):
    """Gradient of the anchor loss with respect to beta."""
    f = X @ beta
    p = predictions(f)
    r = (y / 2 + 0.5) - p
    return -(r + 2 * (gamma - 1) * proj(A, r) * p * (1 - p)) @ X


def hess(X, beta, y, A, gamma):
    """Hessian of the anchor loss with respect to beta."""
    f = X @ beta
    p = predictions(f)
    r = (y / 2 + 0.5) - p
    pq = p * (1 - p)
    diag = np.diag(pq * (1 - 2 * (gamma - 1) * (1 - 2 * p) * proj(A, r)))
    dense = proj_matrix(A) * np.outer(pq, pq)
    return X.T @ (diag + 2 * (gamma - 1) * dense) @ X


@pytest.mark.parametrize("gamma", [0, 0.1, 0.8, 1, 5])
def test_grad_hess(gamma):
    rng = np.random.default_rng(0)
    n, p, q = 100, 10, 3
    X = rng.normal(size=(n, p))
    beta = rng.normal(size=p)
    y = 2 * rng.binomial(1, 0.5, n) - 1
    A = rng.normal(size=(n, q))

    approx_grad = approx_fprime(beta, lambda b: loss(X, b, y, A, gamma))
    np.testing.assert_allclose(approx_grad, grad(X, beta, y, A, gamma), 1e-5)

    approx_hess = approx_fprime(beta, lambda b: grad(X, b, y, A, gamma), 1e-7)
    np.testing.assert_allclose(approx_hess, hess(X, beta, y, A, gamma), 1e-5)
```
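To see that these derivatives are actually usable for fitting, here is a minimal (undamped) Newton sketch built on the `loss`/`grad`/`hess` functions above. It is purely illustrative, not a proposal for glum's solver; a real implementation would need step-size control and convergence checks:

```python
def fit_anchor_logistic(X, y, A, gamma, n_iter=25):
    # Undamped Newton-Raphson on the anchor logistic loss (illustration only).
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        beta -= np.linalg.solve(hess(X, beta, y, A, gamma), grad(X, beta, y, A, gamma))
    return beta
```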
I understand that `glum` implements different solvers. As $\ell_1$-regularization is popular in the robustness community, the `irls` solver is most interesting.
To my understanding, the computation of the full projection matrix above can be skipped using a QR decomposition of $A$. However, in your implementation, you never actually compute the Hessian, but rather an approximation. And your implementation appears to depend heavily on the Hessian being of the form $X^T D X$ for some diagonal $D$, which is no longer the case here.
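To illustrate the QR idea: with a thin QR decomposition $A = QR$, the columns of $Q$ form an orthonormal basis of the column space of $A$, so $\Pi_A v = Q (Q^T v)$ can be applied matrix-free. A small self-contained check (variable names are mine):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(100, 3))
v = rng.normal(size=100)

# Thin QR: Pi_A v = Q @ (Q.T @ v), without forming the n x n matrix Pi_A.
Q, _ = np.linalg.qr(A)
v_proj = Q @ (Q.T @ v)

np.testing.assert_allclose(v_proj, A @ np.linalg.lstsq(A, v, rcond=None)[0])
```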
Summary
Anchor regression interpolates between OLS and IV regression to improve a model's robustness to distribution shifts.
Distributional anchor regression is a generalization to GLMs. To my knowledge, no efficient solver for distributional anchor regression exists.
Is this something you would be interested in integrating into `glum`? How complex would this be? Are there any hurdles (e.g., the dense Hessian) that prohibit the use of existing methods?
References
[1] Rothenhäusler, D., N. Meinshausen, P. Bühlmann, and J. Peters (2021). Anchor regression: Heterogeneous data meet causality. Journal of the Royal Statistical Society Series B (Statistical Methodology) 83(2), 215–246.
[2] Bühlmann, P. (2020). Invariance, causality and robustness. Statistical Science 35(3), 404–426.
[3] Kook, L., B. Sick, and P. Bühlmann (2022). Distributional anchor regression. Statistics and Computing 32(3), 1–19.