Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
xunzheng committed Nov 6, 2019
1 parent dd710f4 commit 1453c49
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions notears.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ def _loss(W):
if loss_type == 'l2':
R = X - M
loss = 0.5 / X.shape[0] * (R ** 2).sum()
D = - 1.0 / X.shape[0] * X.T @ R
G_loss = - 1.0 / X.shape[0] * X.T @ R
elif loss_type == 'logistic':
loss = 1.0 / X.shape[0] * (np.logaddexp(0, M) - X * M).sum()
D = 1.0 / X.shape[0] * X.T @ (sigmoid(M) - X)
G_loss = 1.0 / X.shape[0] * X.T @ (sigmoid(M) - X)
elif loss_type == 'poisson':
S = np.exp(M)
loss = 1.0 / X.shape[0] * (S - X * M).sum()
D = 1.0 / X.shape[0] * X.T @ (S - X)
G_loss = 1.0 / X.shape[0] * X.T @ (S - X)
else:
raise ValueError('unknown loss type')
return loss, D
return loss, G_loss

def _h(W):
"""Evaluate value and gradient of acyclicity constraint."""
Expand All @@ -44,7 +44,8 @@ def _h(W):
M = np.eye(d) + W * W / d # (Yu et al. 2019)
E = np.linalg.matrix_power(M, d - 1)
h = (E.T * M).sum() - d
return h, E
G_h = E.T * W * 2
return h, G_h

def _adj(w):
"""Convert doubled variables ([2 d^2] array) back to original variables ([d, d] matrix)."""
Expand All @@ -53,12 +54,12 @@ def _adj(w):
def _func(w):
"""Evaluate value and gradient of augmented Lagrangian for doubled variables ([2 d^2] array)."""
W = _adj(w)
loss, D = _loss(W)
h, E = _h(W)
loss, G_loss = _loss(W)
h, G_h = _h(W)
obj = loss + 0.5 * rho * h * h + alpha * h + lambda1 * w.sum()
G = D + (rho * h + alpha) * E.T * W * 2
grad_cat = np.concatenate((G + lambda1, - G + lambda1), axis=None)
return obj, grad_cat
G_smooth = G_loss + (rho * h + alpha) * G_h
g_obj = np.concatenate((G_smooth + lambda1, - G_smooth + lambda1), axis=None)
return obj, g_obj

n, d = X.shape
w_est, rho, alpha, h = np.zeros(2 * d * d), 1.0, 0.0, np.inf # double w_est into (w_pos, w_neg)
Expand Down

0 comments on commit 1453c49

Please sign in to comment.